From efc00ed6bb838dceaee7ad9469cc51d1500a365d Mon Sep 17 00:00:00 2001
From: Eric Schmidt
Date: Fri, 12 Feb 2021 14:42:13 -0800
Subject: [PATCH] feat: updates python-aiplatform to v1 (#212)

* feat: updates python-aiplatform to v1
* fix: adjusts import statement
* chore: adds v1 as default aiplatform
* fix: more changes to gapic/__init__.py
* fix: reverted explainability samples to v1beta1
* fix: sample failures due to v1beta1 features
---
 .coveragerc | 18 +
 .flake8 | 1 +
 .github/header-checker-lint.yml | 15 +
 .kokoro/build.sh | 16 +-
 .kokoro/docs/docs-presubmit.cfg | 11 +
 .pre-commit-config.yaml | 17 +
 .trampolinerc | 1 +
 CONTRIBUTING.rst | 21 +-
 LICENSE | 7 +-
 docs/_static/custom.css | 7 +-
 docs/aiplatform_v1/dataset_service.rst | 11 +
 docs/aiplatform_v1/endpoint_service.rst | 11 +
 docs/aiplatform_v1/job_service.rst | 11 +
 docs/aiplatform_v1/migration_service.rst | 11 +
 docs/aiplatform_v1/model_service.rst | 11 +
 docs/aiplatform_v1/pipeline_service.rst | 11 +
 docs/aiplatform_v1/prediction_service.rst | 6 +
 docs/aiplatform_v1/services.rst | 13 +
 .../aiplatform_v1/specialist_pool_service.rst | 11 +
 docs/aiplatform_v1/types.rst | 7 +
 docs/aiplatform_v1beta1/dataset_service.rst | 11 +
 docs/aiplatform_v1beta1/endpoint_service.rst | 11 +
 docs/aiplatform_v1beta1/job_service.rst | 11 +
 docs/aiplatform_v1beta1/migration_service.rst | 11 +
 docs/aiplatform_v1beta1/model_service.rst | 11 +
 docs/aiplatform_v1beta1/pipeline_service.rst | 11 +
 .../aiplatform_v1beta1/prediction_service.rst | 6 +
 docs/aiplatform_v1beta1/services.rst | 34 +-
 .../specialist_pool_service.rst | 11 +
 docs/aiplatform_v1beta1/types.rst | 1 +
 docs/definition_v1/types.rst | 7 +
 docs/definition_v1beta1/types.rst | 1 +
 docs/instance_v1/types.rst | 7 +
 docs/instance_v1beta1/types.rst | 1 +
 docs/params_v1/types.rst | 7 +
 docs/params_v1beta1/types.rst | 1 +
 docs/prediction_v1/types.rst | 7 +
 docs/prediction_v1beta1/types.rst | 1 +
 google/cloud/aiplatform/gapic/__init__.py | 4 +-
 .../v1/schema/predict/instance/__init__.py | 56 +
 .../v1/schema/predict/instance/py.typed | 2 +
 .../v1/schema/predict/instance_v1/__init__.py | 39 +
 .../v1/schema/predict/instance_v1/py.typed | 2 +
 .../predict/instance_v1/types/__init__.py | 38 +
 .../instance_v1/types/image_classification.py | 51 +
 .../types/image_object_detection.py | 51 +
 .../instance_v1/types/image_segmentation.py | 45 +
 .../instance_v1/types/text_classification.py | 44 +
 .../instance_v1/types/text_extraction.py | 55 +
 .../instance_v1/types/text_sentiment.py | 44 +
 .../types/video_action_recognition.py | 64 +
 .../instance_v1/types/video_classification.py | 64 +
 .../types/video_object_tracking.py | 64 +
 .../v1/schema/predict/params/__init__.py | 44 +
 .../v1/schema/predict/params/py.typed | 2 +
 .../v1/schema/predict/params_v1/__init__.py | 33 +
 .../v1/schema/predict/params_v1/py.typed | 2 +
 .../predict/params_v1/types/__init__.py | 32 +
 .../params_v1/types/image_classification.py | 47 +
 .../params_v1/types/image_object_detection.py | 48 +
 .../params_v1/types/image_segmentation.py | 42 +
 .../types/video_action_recognition.py | 48 +
 .../params_v1/types/video_classification.py | 85 +
 .../params_v1/types/video_object_tracking.py | 54 +
 .../v1/schema/predict/prediction/__init__.py | 60 +
 .../v1/schema/predict/prediction/py.typed | 2 +
 .../schema/predict/prediction_v1/__init__.py | 41 +
 .../v1/schema/predict/prediction_v1/py.typed | 2 +
 .../predict/prediction_v1/types/__init__.py | 40 +
 .../prediction_v1/types/classification.py | 51 +
 .../types/image_object_detection.py | 64 +
 .../prediction_v1/types/image_segmentation.py | 57 +
 .../types/tabular_classification.py | 47 +
 .../prediction_v1/types/tabular_regression.py | 46 +
 .../prediction_v1/types/text_extraction.py | 67 +
 .../prediction_v1/types/text_sentiment.py | 45 +
 .../types/video_action_recognition.py | 74 +
 .../types/video_classification.py | 90 +
 .../types/video_object_tracking.py | 115 +
 .../schema/trainingjob/definition/__init__.py | 120 +
 .../v1/schema/trainingjob/definition/py.typed | 2 +
 .../trainingjob/definition_v1/__init__.py | 71 +
 .../schema/trainingjob/definition_v1/py.typed | 2 +
 .../definition_v1/types/__init__.py | 90 +
 .../types/automl_image_classification.py | 143 +
 .../types/automl_image_object_detection.py | 128 +
 .../types/automl_image_segmentation.py | 122 +
 .../definition_v1/types/automl_tables.py | 447 ++
 .../types/automl_text_classification.py | 52 +
 .../types/automl_text_extraction.py | 43 +
 .../types/automl_text_sentiment.py | 59 +
 .../types/automl_video_action_recognition.py | 58 +
 .../types/automl_video_classification.py | 59 +
 .../types/automl_video_object_tracking.py | 62 +
 .../export_evaluated_data_items_config.py | 53 +
 .../instance_v1beta1/types/__init__.py | 1 -
 .../predict/params_v1beta1/types/__init__.py | 1 -
 .../prediction_v1beta1/types/__init__.py | 1 -
 .../types/image_object_detection.py | 2 +-
 .../types/text_sentiment.py | 45 +-
 .../types/video_action_recognition.py | 6 +-
 .../types/video_classification.py | 6 +-
 .../types/video_object_tracking.py | 18 +-
 .../definition_v1beta1/types/__init__.py | 1 -
 .../types/automl_forecasting.py | 40 +-
 .../types/automl_image_classification.py | 8 +-
 .../types/automl_image_object_detection.py | 8 +-
 .../types/automl_image_segmentation.py | 8 +-
 .../definition_v1beta1/types/automl_tables.py | 34 +-
 .../types/automl_text_classification.py | 2 +-
 .../types/automl_text_extraction.py | 2 +-
 .../types/automl_text_sentiment.py | 2 +-
 .../types/automl_video_action_recognition.py | 4 +-
 .../types/automl_video_classification.py | 4 +-
 .../types/automl_video_object_tracking.py | 4 +-
 google/cloud/aiplatform_v1/__init__.py | 345 +
 google/cloud/aiplatform_v1/py.typed | 2 +
 .../cloud/aiplatform_v1/services/__init__.py | 16 +
 .../services/dataset_service/__init__.py | 24 +
 .../services/dataset_service/async_client.py | 1044 +++
 .../services/dataset_service/client.py | 1309 ++++
 .../services/dataset_service/pagers.py | 407 ++
 .../dataset_service/transports/__init__.py | 35 +
 .../dataset_service/transports/base.py | 254 +
 .../dataset_service/transports/grpc.py | 539 ++
 .../transports/grpc_asyncio.py | 554 ++
 .../services/endpoint_service/__init__.py | 24 +
 .../services/endpoint_service/async_client.py | 841 +++
 .../services/endpoint_service/client.py | 1064 +++
 .../services/endpoint_service/pagers.py | 149 +
 .../endpoint_service/transports/__init__.py | 35 +
 .../endpoint_service/transports/base.py | 208 +
 .../endpoint_service/transports/grpc.py | 456 ++
 .../transports/grpc_asyncio.py | 473 ++
 .../services/job_service/__init__.py | 24 +
 .../services/job_service/async_client.py | 1874 +++++
 .../services/job_service/client.py | 2204 ++++++
 .../services/job_service/pagers.py | 542 ++
 .../job_service/transports/__init__.py | 35 +
 .../services/job_service/transports/base.py | 434 ++
 .../services/job_service/transports/grpc.py | 886 +++
 .../job_service/transports/grpc_asyncio.py | 907 +++
 .../services/migration_service/__init__.py | 24 +
 .../migration_service/async_client.py | 367 +
 .../services/migration_service/client.py | 654 ++
 .../services/migration_service/pagers.py | 153 +
 .../migration_service/transports/__init__.py | 35 +
 .../migration_service/transports/base.py | 150 +
 .../migration_service/transports/grpc.py | 333 +
 .../transports/grpc_asyncio.py | 340 +
 .../services/model_service/__init__.py | 24 +
 .../services/model_service/async_client.py | 1031 +++
 .../services/model_service/client.py | 1307 ++++
 .../services/model_service/pagers.py | 411 ++
 .../model_service/transports/__init__.py | 35 +
 .../services/model_service/transports/base.py | 266 +
 .../services/model_service/transports/grpc.py | 547 ++
 .../model_service/transports/grpc_asyncio.py | 558 ++
 .../services/pipeline_service/__init__.py | 24 +
 .../services/pipeline_service/async_client.py | 600 ++
 .../services/pipeline_service/client.py | 835 +++
 .../services/pipeline_service/pagers.py | 153 +
 .../pipeline_service/transports/__init__.py | 35 +
 .../pipeline_service/transports/base.py | 201 +
 .../pipeline_service/transports/grpc.py | 426 ++
 .../transports/grpc_asyncio.py | 435 ++
 .../services/prediction_service/__init__.py | 24 +
 .../prediction_service/async_client.py | 263 +
 .../services/prediction_service/client.py | 468 ++
 .../prediction_service/transports/__init__.py | 35 +
 .../prediction_service/transports/base.py | 127 +
 .../prediction_service/transports/grpc.py | 280 +
 .../transports/grpc_asyncio.py | 285 +
 .../specialist_pool_service/__init__.py | 24 +
 .../specialist_pool_service/async_client.py | 639 ++
 .../specialist_pool_service/client.py | 841 +++
 .../specialist_pool_service/pagers.py | 153 +
 .../transports/__init__.py | 37 +
 .../transports/base.py | 194 +
 .../transports/grpc.py | 418 ++
 .../transports/grpc_asyncio.py | 427 ++
 google/cloud/aiplatform_v1/types/__init__.py | 361 +
 .../aiplatform_v1/types/accelerator_type.py | 38 +
 .../cloud/aiplatform_v1/types/annotation.py | 109 +
 .../aiplatform_v1/types/annotation_spec.py | 65 +
 .../types/batch_prediction_job.py | 344 +
 .../aiplatform_v1/types/completion_stats.py | 55 +
 .../cloud/aiplatform_v1/types/custom_job.py | 318 +
 google/cloud/aiplatform_v1/types/data_item.py | 84 +
 .../aiplatform_v1/types/data_labeling_job.py | 275 +
 google/cloud/aiplatform_v1/types/dataset.py | 185 +
 .../aiplatform_v1/types/dataset_service.py | 446 ++
 .../aiplatform_v1/types/deployed_model_ref.py | 42 +
 .../aiplatform_v1/types/encryption_spec.py | 43 +
 google/cloud/aiplatform_v1/types/endpoint.py | 200 +
 .../aiplatform_v1/types/endpoint_service.py | 337 +
 google/cloud/aiplatform_v1/types/env_var.py | 48 +
 .../types/hyperparameter_tuning_job.py | 142 +
 google/cloud/aiplatform_v1/types/io.py | 120 +
 .../cloud/aiplatform_v1/types/job_service.py | 619 ++
 google/cloud/aiplatform_v1/types/job_state.py | 39 +
 .../aiplatform_v1/types/machine_resources.py | 212 +
 .../types/manual_batch_tuning_parameters.py | 46 +
 .../types/migratable_resource.py | 179 +
 .../aiplatform_v1/types/migration_service.py | 376 +
 google/cloud/aiplatform_v1/types/model.py | 630 ++
 .../aiplatform_v1/types/model_evaluation.py | 73 +
 .../types/model_evaluation_slice.py | 91 +
 .../aiplatform_v1/types/model_service.py | 487 ++
 google/cloud/aiplatform_v1/types/operation.py | 73 +
 .../aiplatform_v1/types/pipeline_service.py | 175 +
 .../aiplatform_v1/types/pipeline_state.py | 39 +
 .../aiplatform_v1/types/prediction_service.py | 88 +
 .../aiplatform_v1/types/specialist_pool.py | 68 +
 .../types/specialist_pool_service.py | 205 +
 google/cloud/aiplatform_v1/types/study.py | 444 ++
 .../aiplatform_v1/types/training_pipeline.py | 467 ++
 .../types/user_action_reference.py | 54 +
 google/cloud/aiplatform_v1beta1/__init__.py | 6 +
 .../services/dataset_service/async_client.py | 98 +-
 .../services/dataset_service/client.py | 155 +-
 .../services/dataset_service/pagers.py | 48 +-
 .../dataset_service/transports/__init__.py | 1 -
 .../dataset_service/transports/base.py | 8 +-
 .../dataset_service/transports/grpc.py | 45 +-
 .../transports/grpc_asyncio.py | 46 +-
 .../services/endpoint_service/async_client.py | 90 +-
 .../services/endpoint_service/client.py | 143 +-
 .../services/endpoint_service/pagers.py | 16 +-
 .../endpoint_service/transports/__init__.py | 1 -
 .../endpoint_service/transports/base.py | 8 +-
 .../endpoint_service/transports/grpc.py | 45 +-
 .../transports/grpc_asyncio.py | 46 +-
 .../services/job_service/async_client.py | 221 +-
 .../services/job_service/client.py | 314 +-
 .../services/job_service/pagers.py | 64 +-
 .../job_service/transports/__init__.py | 1 -
 .../services/job_service/transports/base.py | 8 +-
 .../services/job_service/transports/grpc.py | 45 +-
 .../job_service/transports/grpc_asyncio.py | 46 +-
 .../migration_service/async_client.py | 16 +-
 .../services/migration_service/client.py | 59 +-
 .../services/migration_service/pagers.py | 16 +-
 .../migration_service/transports/__init__.py | 1 -
 .../migration_service/transports/base.py | 8 +-
 .../migration_service/transports/grpc.py | 45 +-
 .../transports/grpc_asyncio.py | 46 +-
 .../services/model_service/async_client.py | 93 +-
 .../services/model_service/client.py | 150 +-
 .../services/model_service/pagers.py | 48 +-
 .../model_service/transports/__init__.py | 1 -
 .../services/model_service/transports/base.py | 8 +-
 .../services/model_service/transports/grpc.py | 45 +-
 .../model_service/transports/grpc_asyncio.py | 46 +-
 .../services/pipeline_service/async_client.py | 74 +-
 .../services/pipeline_service/client.py | 123 +-
 .../services/pipeline_service/pagers.py | 16 +-
 .../pipeline_service/transports/__init__.py | 1 -
 .../pipeline_service/transports/base.py | 8 +-
 .../pipeline_service/transports/grpc.py | 45 +-
 .../transports/grpc_asyncio.py | 46 +-
 .../prediction_service/async_client.py | 24 +-
 .../services/prediction_service/client.py | 69 +-
 .../prediction_service/transports/__init__.py | 1 -
 .../prediction_service/transports/base.py | 8 +-
 .../prediction_service/transports/grpc.py | 36 +-
 .../transports/grpc_asyncio.py | 39 +-
 .../specialist_pool_service/async_client.py | 92 +-
 .../specialist_pool_service/client.py | 139 +-
 .../specialist_pool_service/pagers.py | 16 +-
 .../transports/__init__.py | 1 -
 .../transports/base.py | 8 +-
 .../transports/grpc.py | 45 +-
 .../transports/grpc_asyncio.py | 46 +-
 .../aiplatform_v1beta1/types/__init__.py | 10 +-
 .../aiplatform_v1beta1/types/annotation.py | 10 +-
 .../types/annotation_spec.py | 4 +-
 .../types/batch_prediction_job.py | 98 +-
 .../aiplatform_v1beta1/types/custom_job.py | 85 +-
 .../aiplatform_v1beta1/types/data_item.py | 8 +-
 .../types/data_labeling_job.py | 40 +-
 .../cloud/aiplatform_v1beta1/types/dataset.py | 24 +-
 .../types/dataset_service.py | 54 +-
 .../types/encryption_spec.py | 43 +
 .../aiplatform_v1beta1/types/endpoint.py | 30 +-
 .../types/endpoint_service.py | 44 +-
 .../cloud/aiplatform_v1beta1/types/env_var.py | 16 +-
 .../aiplatform_v1beta1/types/explanation.py | 147 +-
 .../types/explanation_metadata.py | 30 +-
 .../types/hyperparameter_tuning_job.py | 31 +-
 google/cloud/aiplatform_v1beta1/types/io.py | 13 +-
 .../aiplatform_v1beta1/types/job_service.py | 24 +-
 .../types/machine_resources.py | 59 +-
 .../types/migratable_resource.py | 14 +-
 .../types/migration_service.py | 76 +-
 .../cloud/aiplatform_v1beta1/types/model.py | 75 +-
 .../types/model_evaluation.py | 36 +-
 .../types/model_evaluation_slice.py | 6 +-
 .../aiplatform_v1beta1/types/model_service.py | 71 +-
 .../aiplatform_v1beta1/types/operation.py | 8 +-
 .../types/pipeline_service.py | 6 +-
 .../types/prediction_service.py | 29 +-
 .../types/specialist_pool_service.py | 14 +-
 .../cloud/aiplatform_v1beta1/types/study.py | 83 +-
 .../types/training_pipeline.py | 112 +-
 noxfile.py | 12 +
 ...te_batch_prediction_job_bigquery_sample.py | 4 +-
 ...ediction_job_tabular_forecasting_sample.py | 4 +-
 samples/snippets/explain_tabular_sample.py | 4 +-
 ...ort_model_tabular_classification_sample.py | 4 +-
 ..._explain_image_managed_container_sample.py | 16 +-
 ...xplain_tabular_managed_container_sample.py | 16 +-
 synth.metadata | 23 +-
 synth.py | 114 +-
 tests/unit/gapic/aiplatform_v1/__init__.py | 1 +
 .../aiplatform_v1/test_dataset_service.py | 3553 ++++++++++
 .../aiplatform_v1/test_endpoint_service.py | 2634 +++++++
 .../gapic/aiplatform_v1/test_job_service.py | 6161 +++++++++++++++++
 .../aiplatform_v1/test_migration_service.py | 1775 +++++
 .../gapic/aiplatform_v1/test_model_service.py | 3729 ++++++++++
 .../aiplatform_v1/test_pipeline_service.py | 2288 ++++++
 .../test_specialist_pool_service.py | 2325 +++++++
 .../test_dataset_service.py | 219 +-
 .../test_endpoint_service.py | 219 +-
 .../aiplatform_v1beta1/test_job_service.py | 269 +-
 .../test_migration_service.py | 218 +-
 .../aiplatform_v1beta1/test_model_service.py | 221 +-
 .../test_pipeline_service.py | 219 +-
 .../test_prediction_service.py | 218 +-
 .../test_specialist_pool_service.py | 220 +-
 330 files changed, 65473 insertions(+), 2524 deletions(-)
 create mode 100644 .coveragerc
 create mode 100644 .github/header-checker-lint.yml
 create mode 100644 .pre-commit-config.yaml
 create mode 100644 docs/aiplatform_v1/dataset_service.rst
 create mode 100644 docs/aiplatform_v1/endpoint_service.rst
 create mode 100644 docs/aiplatform_v1/job_service.rst
 create mode 100644 docs/aiplatform_v1/migration_service.rst
 create mode 100644 docs/aiplatform_v1/model_service.rst
 create mode 100644 docs/aiplatform_v1/pipeline_service.rst
 create mode 100644 docs/aiplatform_v1/prediction_service.rst
 create mode 100644 docs/aiplatform_v1/services.rst
 create mode 100644 docs/aiplatform_v1/specialist_pool_service.rst
 create mode 100644 docs/aiplatform_v1/types.rst
 create mode 100644 docs/aiplatform_v1beta1/dataset_service.rst
 create mode 100644 docs/aiplatform_v1beta1/endpoint_service.rst
 create mode 100644 docs/aiplatform_v1beta1/job_service.rst
 create mode 100644 docs/aiplatform_v1beta1/migration_service.rst
 create mode 100644 docs/aiplatform_v1beta1/model_service.rst
 create mode 100644 docs/aiplatform_v1beta1/pipeline_service.rst
 create mode 100644 docs/aiplatform_v1beta1/prediction_service.rst
 create mode 100644 docs/aiplatform_v1beta1/specialist_pool_service.rst
 create mode 100644 docs/definition_v1/types.rst
 create mode 100644 docs/instance_v1/types.rst
 create mode 100644 docs/params_v1/types.rst
 create mode 100644 docs/prediction_v1/types.rst
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance/__init__.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance/py.typed
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/__init__.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/py.typed
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/params/__init__.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/params/py.typed
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/params_v1/py.typed
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction/py.typed
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/py.typed
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py
 create mode 100644 google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition/py.typed
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/py.typed
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py
 create mode 100644 google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py
 create mode 100644 google/cloud/aiplatform_v1/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/py.typed
 create mode 100644 google/cloud/aiplatform_v1/services/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/dataset_service/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/dataset_service/async_client.py
 create mode 100644 google/cloud/aiplatform_v1/services/dataset_service/client.py
 create mode 100644 google/cloud/aiplatform_v1/services/dataset_service/pagers.py
 create mode 100644 google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/dataset_service/transports/base.py
 create mode 100644 google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py
 create mode 100644 google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1/services/endpoint_service/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/endpoint_service/async_client.py
 create mode 100644 google/cloud/aiplatform_v1/services/endpoint_service/client.py
 create mode 100644 google/cloud/aiplatform_v1/services/endpoint_service/pagers.py
 create mode 100644 google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py
 create mode 100644 google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py
 create mode 100644 google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1/services/job_service/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/job_service/async_client.py
 create mode 100644 google/cloud/aiplatform_v1/services/job_service/client.py
 create mode 100644 google/cloud/aiplatform_v1/services/job_service/pagers.py
 create mode 100644 google/cloud/aiplatform_v1/services/job_service/transports/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/job_service/transports/base.py
 create mode 100644 google/cloud/aiplatform_v1/services/job_service/transports/grpc.py
 create mode 100644 google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1/services/migration_service/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/migration_service/async_client.py
 create mode 100644 google/cloud/aiplatform_v1/services/migration_service/client.py
 create mode 100644 google/cloud/aiplatform_v1/services/migration_service/pagers.py
 create mode 100644 google/cloud/aiplatform_v1/services/migration_service/transports/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/migration_service/transports/base.py
 create mode 100644 google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py
 create mode 100644 google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1/services/model_service/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/model_service/async_client.py
 create mode 100644 google/cloud/aiplatform_v1/services/model_service/client.py
 create mode 100644 google/cloud/aiplatform_v1/services/model_service/pagers.py
 create mode 100644 google/cloud/aiplatform_v1/services/model_service/transports/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/model_service/transports/base.py
 create mode 100644 google/cloud/aiplatform_v1/services/model_service/transports/grpc.py
 create mode 100644 google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1/services/pipeline_service/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/pipeline_service/async_client.py
 create mode 100644 google/cloud/aiplatform_v1/services/pipeline_service/client.py
 create mode 100644 google/cloud/aiplatform_v1/services/pipeline_service/pagers.py
 create mode 100644 google/cloud/aiplatform_v1/services/pipeline_service/transports/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py
 create mode 100644 google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py
 create mode 100644 google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1/services/prediction_service/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/prediction_service/async_client.py
 create mode 100644 google/cloud/aiplatform_v1/services/prediction_service/client.py
 create mode 100644 google/cloud/aiplatform_v1/services/prediction_service/transports/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/prediction_service/transports/base.py
 create mode 100644 google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py
 create mode 100644 google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1/services/specialist_pool_service/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py
 create mode 100644 google/cloud/aiplatform_v1/services/specialist_pool_service/client.py
 create mode 100644 google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py
 create mode 100644 google/cloud/aiplatform_v1/services/specialist_pool_service/transports/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py
 create mode 100644 google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py
 create mode 100644 google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py
 create mode 100644 google/cloud/aiplatform_v1/types/__init__.py
 create mode 100644 google/cloud/aiplatform_v1/types/accelerator_type.py
 create mode 100644 google/cloud/aiplatform_v1/types/annotation.py
 create mode 100644 google/cloud/aiplatform_v1/types/annotation_spec.py
 create mode 100644 google/cloud/aiplatform_v1/types/batch_prediction_job.py
 create mode 100644 google/cloud/aiplatform_v1/types/completion_stats.py
 create mode 100644 google/cloud/aiplatform_v1/types/custom_job.py
 create mode 100644 google/cloud/aiplatform_v1/types/data_item.py
 create mode 100644 google/cloud/aiplatform_v1/types/data_labeling_job.py
 create mode 100644 google/cloud/aiplatform_v1/types/dataset.py
 create mode 100644 google/cloud/aiplatform_v1/types/dataset_service.py
 create mode 100644 google/cloud/aiplatform_v1/types/deployed_model_ref.py
 create mode 100644 google/cloud/aiplatform_v1/types/encryption_spec.py
 create mode 100644 google/cloud/aiplatform_v1/types/endpoint.py
 create mode 100644 google/cloud/aiplatform_v1/types/endpoint_service.py
 create mode 100644 google/cloud/aiplatform_v1/types/env_var.py
 create mode 100644 google/cloud/aiplatform_v1/types/hyperparameter_tuning_job.py
 create mode 100644 google/cloud/aiplatform_v1/types/io.py
 create mode 100644 google/cloud/aiplatform_v1/types/job_service.py
 create mode 100644 google/cloud/aiplatform_v1/types/job_state.py
 create mode 100644 google/cloud/aiplatform_v1/types/machine_resources.py
 create mode 100644 google/cloud/aiplatform_v1/types/manual_batch_tuning_parameters.py
 create mode 100644 google/cloud/aiplatform_v1/types/migratable_resource.py
 create mode 100644 google/cloud/aiplatform_v1/types/migration_service.py
 create mode 100644 google/cloud/aiplatform_v1/types/model.py
 create mode 100644 google/cloud/aiplatform_v1/types/model_evaluation.py
 create mode 100644 google/cloud/aiplatform_v1/types/model_evaluation_slice.py
 create mode 100644 google/cloud/aiplatform_v1/types/model_service.py
 create mode 100644 google/cloud/aiplatform_v1/types/operation.py
 create mode 100644 google/cloud/aiplatform_v1/types/pipeline_service.py
 create mode 100644 google/cloud/aiplatform_v1/types/pipeline_state.py
 create mode 100644 google/cloud/aiplatform_v1/types/prediction_service.py
 create mode 100644 google/cloud/aiplatform_v1/types/specialist_pool.py
 create mode 100644 google/cloud/aiplatform_v1/types/specialist_pool_service.py
 create mode 100644 google/cloud/aiplatform_v1/types/study.py
 create mode 100644 google/cloud/aiplatform_v1/types/training_pipeline.py
 create mode 100644 google/cloud/aiplatform_v1/types/user_action_reference.py
 create mode 100644 google/cloud/aiplatform_v1beta1/types/encryption_spec.py
 create mode 100644 tests/unit/gapic/aiplatform_v1/__init__.py
 create mode 100644 tests/unit/gapic/aiplatform_v1/test_dataset_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1/test_endpoint_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1/test_job_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1/test_migration_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1/test_model_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1/test_pipeline_service.py
 create mode 100644 tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py

diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000000..5b3f287a0f
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,18 @@
+[run]
+branch = True
+
+[report]
+fail_under = 100
+show_missing = True
+omit =
+  google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py
+exclude_lines =
+  # Re-enable the standard pragma
+  pragma: NO COVER
+  # Ignore debug-only repr
+  def __repr__
+  # Ignore pkg_resources exceptions.
+  # This is added at the module level as a safeguard for if someone
+  # generates the code and tries to run it without pip installing. This
+  # makes it virtually impossible to test properly.
+  except pkg_resources.DistributionNotFound
diff --git a/.flake8 b/.flake8
index ed9316381c..29227d4cf4 100644
--- a/.flake8
+++ b/.flake8
@@ -26,6 +26,7 @@ exclude =
   *_pb2.py

   # Standard linting exemptions.
+  **/.nox/**
   __pycache__,
   .git,
   *.pyc,
diff --git a/.github/header-checker-lint.yml b/.github/header-checker-lint.yml
new file mode 100644
index 0000000000..fc281c05bd
--- /dev/null
+++ b/.github/header-checker-lint.yml
@@ -0,0 +1,15 @@
+{"allowedCopyrightHolders": ["Google LLC"],
+ "allowedLicenses": ["Apache-2.0", "MIT", "BSD-3"],
+ "ignoreFiles": ["**/requirements.txt", "**/requirements-test.txt"],
+ "sourceFileExtensions": [
+   "ts",
+   "js",
+   "java",
+   "sh",
+   "Dockerfile",
+   "yaml",
+   "py",
+   "html",
+   "txt"
+ ]
+}
\ No newline at end of file
diff --git a/.kokoro/build.sh b/.kokoro/build.sh
index 4c37ecad0c..ef4eb9c094 100755
--- a/.kokoro/build.sh
+++ b/.kokoro/build.sh
@@ -15,7 +15,11 @@

 set -eo pipefail

-cd github/python-aiplatform
+if [[ -z "${PROJECT_ROOT:-}" ]]; then
+    PROJECT_ROOT="github/python-aiplatform"
+fi
+
+cd "${PROJECT_ROOT}"

 # Disable buffering, so that the logs stream through.
 export PYTHONUNBUFFERED=1
@@ -30,16 +34,16 @@ export GOOGLE_APPLICATION_CREDENTIALS=${KOKORO_GFILE_DIR}/service-account.json
 export PROJECT_ID=$(cat "${KOKORO_GFILE_DIR}/project-id.json")

 # Remove old nox
-python3.6 -m pip uninstall --yes --quiet nox-automation
+python3 -m pip uninstall --yes --quiet nox-automation

 # Install nox
-python3.6 -m pip install --upgrade --quiet nox
-python3.6 -m nox --version
+python3 -m pip install --upgrade --quiet nox
+python3 -m nox --version

 # If NOX_SESSION is set, it only runs the specified session,
 # otherwise run all the sessions.
 if [[ -n "${NOX_SESSION:-}" ]]; then
-    python3.6 -m nox -s "${NOX_SESSION:-}"
+    python3 -m nox -s ${NOX_SESSION:-}
 else
-    python3.6 -m nox
+    python3 -m nox
 fi
diff --git a/.kokoro/docs/docs-presubmit.cfg b/.kokoro/docs/docs-presubmit.cfg
index 1118107829..85c4e08775 100644
--- a/.kokoro/docs/docs-presubmit.cfg
+++ b/.kokoro/docs/docs-presubmit.cfg
@@ -15,3 +15,14 @@ env_vars: {
     key: "TRAMPOLINE_IMAGE_UPLOAD"
     value: "false"
 }
+
+env_vars: {
+    key: "TRAMPOLINE_BUILD_FILE"
+    value: "github/python-aiplatform/.kokoro/build.sh"
+}
+
+# Only run this nox session.
+env_vars: {
+    key: "NOX_SESSION"
+    value: "docs docfx"
+}
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000..a9024b15d7
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,17 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+-   repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.4.0
+    hooks:
+    -   id: trailing-whitespace
+    -   id: end-of-file-fixer
+    -   id: check-yaml
+-   repo: https://github.com/psf/black
+    rev: 19.10b0
+    hooks:
+    -   id: black
+-   repo: https://gitlab.com/pycqa/flake8
+    rev: 3.8.4
+    hooks:
+    -   id: flake8
diff --git a/.trampolinerc b/.trampolinerc
index 995ee29111..383b6ec89f 100644
--- a/.trampolinerc
+++ b/.trampolinerc
@@ -24,6 +24,7 @@ required_envvars+=(
 pass_down_envvars+=(
     "STAGING_BUCKET"
     "V2_STAGING_BUCKET"
+    "NOX_SESSION"
 )

 # Prevent unintentional override on the default image.
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index d206a2d88d..3cab430ce1 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -21,8 +21,8 @@ In order to add a feature:
 - The feature must be documented in both the API and narrative
   documentation.

-- The feature must work fully on the following CPython versions: 2.7,
-  3.5, 3.6, 3.7 and 3.8 on both UNIX and Windows.
+- The feature must work fully on the following CPython versions:
+  3.6, 3.7, 3.8 and 3.9 on both UNIX and Windows.

 - The feature must not add unnecessary dependencies (where
   "unnecessary" is of course subjective, but new dependencies should
@@ -111,6 +111,16 @@ Coding Style
   should point to the official ``googleapis`` checkout and the
   the branch should be the main branch on that remote (``master``).

+- This repository contains configuration for the
+  `pre-commit <https://pre-commit.com/>`__ tool, which automates checking
+  our linters during a commit. If you have it installed on your ``$PATH``,
+  you can enable enforcing those checks via:
+
+.. code-block:: bash
+
+   $ pre-commit install
+   pre-commit installed at .git/hooks/pre-commit
+
 Exceptions to PEP8:

 - Many unit tests use a helper method, ``_call_fut`` ("FUT" is short for
@@ -192,25 +202,24 @@ Supported Python Versions

 We support:

-- `Python 3.5`_
 - `Python 3.6`_
 - `Python 3.7`_
 - `Python 3.8`_
+- `Python 3.9`_

-.. _Python 3.5: https://docs.python.org/3.5/
 .. _Python 3.6: https://docs.python.org/3.6/
 .. _Python 3.7: https://docs.python.org/3.7/
 .. _Python 3.8: https://docs.python.org/3.8/
+.. _Python 3.9: https://docs.python.org/3.9/

 Supported versions can be found in our ``noxfile.py`` `config`_.

 .. _config: https://github.com/googleapis/python-aiplatform/blob/master/noxfile.py

-Python 2.7 support is deprecated. All code changes should maintain Python 2.7 compatibility until January 1, 2020.

 We also explicitly decided to support Python 3 beginning with version
-3.5. Reasons for this include:
+3.6. Reasons for this include:

 - Encouraging use of newest versions of Python 3
 - Taking the lead of `prominent`_ open-source `projects`_
diff --git a/LICENSE b/LICENSE
index a8ee855de2..d645695673 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,7 @@
-                                 Apache License
+
+                                 Apache License
                            Version 2.0, January 2004
-                        https://www.apache.org/licenses/
+                        http://www.apache.org/licenses/

    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

@@ -192,7 +193,7 @@
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

-    https://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/docs/_static/custom.css b/docs/_static/custom.css
index 0abaf229fc..bcd37bbd3c 100644
--- a/docs/_static/custom.css
+++ b/docs/_static/custom.css
@@ -1,4 +1,9 @@
 div#python2-eol {
     border-color: red;
     border-width: medium;
-}
\ No newline at end of file
+}
+
+/* Ensure minimum width for 'Parameters' / 'Returns' column */
+dl.field-list > dt {
+    min-width: 100px
+}
diff --git a/docs/aiplatform_v1/dataset_service.rst b/docs/aiplatform_v1/dataset_service.rst
new file mode 100644
index 0000000000..46694cf2c0
--- /dev/null
+++ b/docs/aiplatform_v1/dataset_service.rst
@@ -0,0 +1,11 @@
+DatasetService
+--------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.dataset_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1.services.dataset_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1/endpoint_service.rst b/docs/aiplatform_v1/endpoint_service.rst
new file mode 100644
index 0000000000..29d05c30b4
--- /dev/null
+++ b/docs/aiplatform_v1/endpoint_service.rst
@@ -0,0 +1,11 @@
+EndpointService
+---------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.endpoint_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1.services.endpoint_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1/job_service.rst b/docs/aiplatform_v1/job_service.rst
new file mode 100644
index 0000000000..6bfd457244
--- /dev/null
+++ b/docs/aiplatform_v1/job_service.rst
@@ -0,0 +1,11 @@
+JobService
+----------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.job_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1.services.job_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1/migration_service.rst b/docs/aiplatform_v1/migration_service.rst
new file mode 100644
index 0000000000..f322a1b3bf
--- /dev/null
+++ b/docs/aiplatform_v1/migration_service.rst
@@ -0,0 +1,11 @@
+MigrationService
+----------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.migration_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1.services.migration_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1/model_service.rst b/docs/aiplatform_v1/model_service.rst
new file mode 100644
index 0000000000..ca269a9ad2
--- /dev/null
+++ b/docs/aiplatform_v1/model_service.rst
@@ -0,0 +1,11 @@
+ModelService
+------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.model_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1.services.model_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1/pipeline_service.rst b/docs/aiplatform_v1/pipeline_service.rst
new file mode 100644
index 0000000000..b718db39b4
--- /dev/null
+++ b/docs/aiplatform_v1/pipeline_service.rst
@@ -0,0 +1,11 @@
+PipelineService
+---------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.pipeline_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1.services.pipeline_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1/prediction_service.rst b/docs/aiplatform_v1/prediction_service.rst
new file mode 100644
index 0000000000..fdda504879
--- /dev/null
+++ b/docs/aiplatform_v1/prediction_service.rst
@@ -0,0 +1,6 @@
+PredictionService
+-----------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.prediction_service
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1/services.rst b/docs/aiplatform_v1/services.rst
new file mode 100644
index 0000000000..fd5a8c9aa7
--- /dev/null
+++ b/docs/aiplatform_v1/services.rst
@@ -0,0 +1,13 @@
+Services for Google Cloud Aiplatform v1 API
+===========================================
+.. toctree::
+    :maxdepth: 2
+
+    dataset_service
+    endpoint_service
+    job_service
+    migration_service
+    model_service
+    pipeline_service
+    prediction_service
+    specialist_pool_service
diff --git a/docs/aiplatform_v1/specialist_pool_service.rst b/docs/aiplatform_v1/specialist_pool_service.rst
new file mode 100644
index 0000000000..37ac386b31
--- /dev/null
+++ b/docs/aiplatform_v1/specialist_pool_service.rst
@@ -0,0 +1,11 @@
+SpecialistPoolService
+---------------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1.services.specialist_pool_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1.services.specialist_pool_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1/types.rst b/docs/aiplatform_v1/types.rst
new file mode 100644
index 0000000000..ad4454843f
--- /dev/null
+++ b/docs/aiplatform_v1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform v1 API
+========================================
+
+.. automodule:: google.cloud.aiplatform_v1.types
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/docs/aiplatform_v1beta1/dataset_service.rst b/docs/aiplatform_v1beta1/dataset_service.rst
new file mode 100644
index 0000000000..ad3866e1e4
--- /dev/null
+++ b/docs/aiplatform_v1beta1/dataset_service.rst
@@ -0,0 +1,11 @@
+DatasetService
+--------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.dataset_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.dataset_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1beta1/endpoint_service.rst b/docs/aiplatform_v1beta1/endpoint_service.rst
new file mode 100644
index 0000000000..c5ce91ed19
--- /dev/null
+++ b/docs/aiplatform_v1beta1/endpoint_service.rst
@@ -0,0 +1,11 @@
+EndpointService
+---------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.endpoint_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.endpoint_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1beta1/job_service.rst b/docs/aiplatform_v1beta1/job_service.rst
new file mode 100644
index 0000000000..eee169a096
--- /dev/null
+++ b/docs/aiplatform_v1beta1/job_service.rst
@@ -0,0 +1,11 @@
+JobService
+----------------------------
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.job_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.job_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1beta1/migration_service.rst b/docs/aiplatform_v1beta1/migration_service.rst
new file mode 100644
index 0000000000..42ff54c101
--- /dev/null
+++ b/docs/aiplatform_v1beta1/migration_service.rst
@@ -0,0 +1,11 @@
+MigrationService
+----------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.migration_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.migration_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1beta1/model_service.rst b/docs/aiplatform_v1beta1/model_service.rst
new file mode 100644
index 0000000000..0fc01a1bd6
--- /dev/null
+++ b/docs/aiplatform_v1beta1/model_service.rst
@@ -0,0 +1,11 @@
+ModelService
+------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.model_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.model_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1beta1/pipeline_service.rst b/docs/aiplatform_v1beta1/pipeline_service.rst
new file mode 100644
index 0000000000..465949eeb0
--- /dev/null
+++ b/docs/aiplatform_v1beta1/pipeline_service.rst
@@ -0,0 +1,11 @@
+PipelineService
+---------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.pipeline_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.pipeline_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1beta1/prediction_service.rst b/docs/aiplatform_v1beta1/prediction_service.rst
new file mode 100644
index 0000000000..03c1150df0
--- /dev/null
+++ b/docs/aiplatform_v1beta1/prediction_service.rst
@@ -0,0 +1,6 @@
+PredictionService
+-----------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.prediction_service
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1beta1/services.rst b/docs/aiplatform_v1beta1/services.rst
index 664c9df0a8..dd8c8a41bc 100644
--- a/docs/aiplatform_v1beta1/services.rst
+++ b/docs/aiplatform_v1beta1/services.rst
@@ -1,27 +1,13 @@
 Services for Google Cloud Aiplatform v1beta1 API
 ================================================
+.. toctree::
+    :maxdepth: 2

-.. automodule:: google.cloud.aiplatform_v1beta1.services.dataset_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.endpoint_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.job_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.migration_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.model_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.pipeline_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.prediction_service
-    :members:
-    :inherited-members:
-.. automodule:: google.cloud.aiplatform_v1beta1.services.specialist_pool_service
-    :members:
-    :inherited-members:
+    dataset_service
+    endpoint_service
+    job_service
+    migration_service
+    model_service
+    pipeline_service
+    prediction_service
+    specialist_pool_service
diff --git a/docs/aiplatform_v1beta1/specialist_pool_service.rst b/docs/aiplatform_v1beta1/specialist_pool_service.rst
new file mode 100644
index 0000000000..4d264dc256
--- /dev/null
+++ b/docs/aiplatform_v1beta1/specialist_pool_service.rst
@@ -0,0 +1,11 @@
+SpecialistPoolService
+---------------------------------------
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.specialist_pool_service
+    :members:
+    :inherited-members:
+
+
+.. automodule:: google.cloud.aiplatform_v1beta1.services.specialist_pool_service.pagers
+    :members:
+    :inherited-members:
diff --git a/docs/aiplatform_v1beta1/types.rst b/docs/aiplatform_v1beta1/types.rst
index 19bab68ada..770675f8ea 100644
--- a/docs/aiplatform_v1beta1/types.rst
+++ b/docs/aiplatform_v1beta1/types.rst
@@ -3,4 +3,5 @@ Types for Google Cloud Aiplatform v1beta1 API

 .. automodule:: google.cloud.aiplatform_v1beta1.types
     :members:
+    :undoc-members:
     :show-inheritance:
diff --git a/docs/definition_v1/types.rst b/docs/definition_v1/types.rst
new file mode 100644
index 0000000000..a1df2bce25
--- /dev/null
+++ b/docs/definition_v1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform V1 Schema Trainingjob Definition v1 API
+=========================================================================
+
+.. automodule:: google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/docs/definition_v1beta1/types.rst b/docs/definition_v1beta1/types.rst
index 3f351d03fc..f4fe7a5301 100644
--- a/docs/definition_v1beta1/types.rst
+++ b/docs/definition_v1beta1/types.rst
@@ -3,4 +3,5 @@ Types for Google Cloud Aiplatform V1beta1 Schema Trainingjob Definition v1beta1

 .. automodule:: google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types
     :members:
+    :undoc-members:
     :show-inheritance:
diff --git a/docs/instance_v1/types.rst b/docs/instance_v1/types.rst
new file mode 100644
index 0000000000..564ab013ee
--- /dev/null
+++ b/docs/instance_v1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform V1 Schema Predict Instance v1 API
+===================================================================
+
+.. automodule:: google.cloud.aiplatform.v1.schema.predict.instance_v1.types
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/docs/instance_v1beta1/types.rst b/docs/instance_v1beta1/types.rst
index c52ae4800c..7caa088065 100644
--- a/docs/instance_v1beta1/types.rst
+++ b/docs/instance_v1beta1/types.rst
@@ -3,4 +3,5 @@ Types for Google Cloud Aiplatform V1beta1 Schema Predict Instance v1beta1 API

 .. automodule:: google.cloud.aiplatform.v1beta1.schema.predict.instance_v1beta1.types
     :members:
+    :undoc-members:
     :show-inheritance:
diff --git a/docs/params_v1/types.rst b/docs/params_v1/types.rst
new file mode 100644
index 0000000000..956ef5224d
--- /dev/null
+++ b/docs/params_v1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform V1 Schema Predict Params v1 API
+=================================================================
+
+.. automodule:: google.cloud.aiplatform.v1.schema.predict.params_v1.types
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/docs/params_v1beta1/types.rst b/docs/params_v1beta1/types.rst
index ce7a29cb01..722a1d8ba0 100644
--- a/docs/params_v1beta1/types.rst
+++ b/docs/params_v1beta1/types.rst
@@ -3,4 +3,5 @@ Types for Google Cloud Aiplatform V1beta1 Schema Predict Params v1beta1 API

 .. automodule:: google.cloud.aiplatform.v1beta1.schema.predict.params_v1beta1.types
     :members:
+    :undoc-members:
     :show-inheritance:
diff --git a/docs/prediction_v1/types.rst b/docs/prediction_v1/types.rst
new file mode 100644
index 0000000000..a97faf34de
--- /dev/null
+++ b/docs/prediction_v1/types.rst
@@ -0,0 +1,7 @@
+Types for Google Cloud Aiplatform V1 Schema Predict Prediction v1 API
+=====================================================================
+
+.. automodule:: google.cloud.aiplatform.v1.schema.predict.prediction_v1.types
+    :members:
+    :undoc-members:
+    :show-inheritance:
diff --git a/docs/prediction_v1beta1/types.rst b/docs/prediction_v1beta1/types.rst
index cdbe7f2842..b14182d6d7 100644
--- a/docs/prediction_v1beta1/types.rst
+++ b/docs/prediction_v1beta1/types.rst
@@ -3,4 +3,5 @@ Types for Google Cloud Aiplatform V1beta1 Schema Predict Prediction v1beta1 API

 .. automodule:: google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types
     :members:
+    :undoc-members:
     :show-inheritance:
diff --git a/google/cloud/aiplatform/gapic/__init__.py b/google/cloud/aiplatform/gapic/__init__.py
index 790ebeffdf..75a9bcd38f 100644
--- a/google/cloud/aiplatform/gapic/__init__.py
+++ b/google/cloud/aiplatform/gapic/__init__.py
@@ -16,9 +16,9 @@
 #
 # The latest GAPIC version is exported to the google.cloud.aiplatform.gapic namespace.
-from google.cloud.aiplatform_v1beta1 import *
+from google.cloud.aiplatform_v1 import *

 from google.cloud.aiplatform.gapic import schema

-from google.cloud import aiplatform_v1beta1 as v1beta1
+from google.cloud import aiplatform_v1 as v1

 __all__ = ()
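Note on the gapic/__init__.py change above: this is the core of the PR. The google.cloud.aiplatform.gapic namespace now re-exports the v1 surface instead of v1beta1. A minimal sketch of what v1 usage looks like after this change (the region in the api_endpoint value is an illustrative placeholder, not something this patch prescribes):

    from google.cloud import aiplatform_v1

    # v1 clients are constructed exactly like their v1beta1
    # counterparts; only the imported namespace changes.
    client = aiplatform_v1.PredictionServiceClient(
        client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
    )

Callers that still need v1beta1-only features (explainability, per the commit message) can keep importing google.cloud.aiplatform_v1beta1 directly, which is what the reverted explainability samples in this patch do.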
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py b/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py
new file mode 100644
index 0000000000..fb2668afb5
--- /dev/null
+++ b/google/cloud/aiplatform/v1/schema/predict/instance/__init__.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_classification import (
+    ImageClassificationPredictionInstance,
+)
+from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_object_detection import (
+    ImageObjectDetectionPredictionInstance,
+)
+from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.image_segmentation import (
+    ImageSegmentationPredictionInstance,
+)
+from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_classification import (
+    TextClassificationPredictionInstance,
+)
+from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_extraction import (
+    TextExtractionPredictionInstance,
+)
+from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.text_sentiment import (
+    TextSentimentPredictionInstance,
+)
+from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_action_recognition import (
+    VideoActionRecognitionPredictionInstance,
+)
+from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_classification import (
+    VideoClassificationPredictionInstance,
+)
+from google.cloud.aiplatform.v1.schema.predict.instance_v1.types.video_object_tracking import (
+    VideoObjectTrackingPredictionInstance,
+)
+
+__all__ = (
+    "ImageClassificationPredictionInstance",
+    "ImageObjectDetectionPredictionInstance",
+    "ImageSegmentationPredictionInstance",
+    "TextClassificationPredictionInstance",
+    "TextExtractionPredictionInstance",
+    "TextSentimentPredictionInstance",
+    "VideoActionRecognitionPredictionInstance",
+    "VideoClassificationPredictionInstance",
+    "VideoObjectTrackingPredictionInstance",
+)
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/py.typed b/google/cloud/aiplatform/v1/schema/predict/instance/py.typed
new file mode 100644
index 0000000000..f70e7f605a
--- /dev/null
+++ b/google/cloud/aiplatform/v1/schema/predict/instance/py.typed
@@ -0,0 +1,2 @@
+# Marker file for PEP 561.
+# The google-cloud-aiplatform-v1-schema-predict-instance package uses inline types.
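The versionless instance package above is a thin alias layer: it re-exports the message classes defined in instance_v1, so the same type is reachable from either import path. A small usage sketch (the field values are illustrative only):

    from google.cloud.aiplatform.v1.schema.predict import instance

    # Same class as instance_v1.types.image_classification exports.
    obj = instance.ImageClassificationPredictionInstance(
        content="<base64-encoded image bytes>",
        mime_type="image/jpeg",
    )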
+# + +from .types.image_classification import ImageClassificationPredictionInstance +from .types.image_object_detection import ImageObjectDetectionPredictionInstance +from .types.image_segmentation import ImageSegmentationPredictionInstance +from .types.text_classification import TextClassificationPredictionInstance +from .types.text_extraction import TextExtractionPredictionInstance +from .types.text_sentiment import TextSentimentPredictionInstance +from .types.video_action_recognition import VideoActionRecognitionPredictionInstance +from .types.video_classification import VideoClassificationPredictionInstance +from .types.video_object_tracking import VideoObjectTrackingPredictionInstance + + +__all__ = ( + "ImageObjectDetectionPredictionInstance", + "ImageSegmentationPredictionInstance", + "TextClassificationPredictionInstance", + "TextExtractionPredictionInstance", + "TextSentimentPredictionInstance", + "VideoActionRecognitionPredictionInstance", + "VideoClassificationPredictionInstance", + "VideoObjectTrackingPredictionInstance", + "ImageClassificationPredictionInstance", +) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/py.typed b/google/cloud/aiplatform/v1/schema/predict/instance_v1/py.typed new file mode 100644 index 0000000000..f70e7f605a --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-cloud-aiplatform-v1-schema-predict-instance package uses inline types. diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py new file mode 100644 index 0000000000..041fe6cdb1 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/__init__.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from .image_classification import ImageClassificationPredictionInstance +from .image_object_detection import ImageObjectDetectionPredictionInstance +from .image_segmentation import ImageSegmentationPredictionInstance +from .text_classification import TextClassificationPredictionInstance +from .text_extraction import TextExtractionPredictionInstance +from .text_sentiment import TextSentimentPredictionInstance +from .video_action_recognition import VideoActionRecognitionPredictionInstance +from .video_classification import VideoClassificationPredictionInstance +from .video_object_tracking import VideoObjectTrackingPredictionInstance + +__all__ = ( + "ImageClassificationPredictionInstance", + "ImageObjectDetectionPredictionInstance", + "ImageSegmentationPredictionInstance", + "TextClassificationPredictionInstance", + "TextExtractionPredictionInstance", + "TextSentimentPredictionInstance", + "VideoActionRecognitionPredictionInstance", + "VideoClassificationPredictionInstance", + "VideoObjectTrackingPredictionInstance", +) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py new file mode 100644 index 0000000000..b5fa9b4dbf --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_classification.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"ImageClassificationPredictionInstance",}, +) + + +class ImageClassificationPredictionInstance(proto.Message): + r"""Prediction input format for Image Classification. + + Attributes: + content (str): + The image bytes or GCS URI to make the + prediction on. + mime_type (str): + The MIME type of the content of the image. + Only the images in the below-listed MIME types are + supported. - image/jpeg + - image/gif + - image/png + - image/webp + - image/bmp + - image/tiff + - image/vnd.microsoft.icon + """ + + content = proto.Field(proto.STRING, number=1) + + mime_type = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py new file mode 100644 index 0000000000..45752ce7e2 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_object_detection.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"ImageObjectDetectionPredictionInstance",}, +) + + +class ImageObjectDetectionPredictionInstance(proto.Message): + r"""Prediction input format for Image Object Detection. + + Attributes: + content (str): + The image bytes or GCS URI to make the + prediction on. + mime_type (str): + The MIME type of the content of the image. + Only the images in the below-listed MIME types are + supported. - image/jpeg + - image/gif + - image/png + - image/webp + - image/bmp + - image/tiff + - image/vnd.microsoft.icon + """ + + content = proto.Field(proto.STRING, number=1) + + mime_type = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py new file mode 100644 index 0000000000..cb436d7029 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/image_segmentation.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"ImageSegmentationPredictionInstance",}, +) + + +class ImageSegmentationPredictionInstance(proto.Message): + r"""Prediction input format for Image Segmentation. + + Attributes: + content (str): + The image bytes to make the predictions on. + mime_type (str): + The MIME type of the content of the image. + Only the images in the below-listed MIME types are + supported. - image/jpeg + - image/png + """ + + content = proto.Field(proto.STRING, number=1) + + mime_type = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py new file mode 100644 index 0000000000..ceff5308b7 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_classification.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"TextClassificationPredictionInstance",}, +) + + +class TextClassificationPredictionInstance(proto.Message): + r"""Prediction input format for Text Classification. + + Attributes: + content (str): + The text snippet to make the predictions on. + mime_type (str): + The MIME type of the text snippet. The + supported MIME types are listed below. + - text/plain + """ + + content = proto.Field(proto.STRING, number=1) + + mime_type = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py new file mode 100644 index 0000000000..2e96216466 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_extraction.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"TextExtractionPredictionInstance",}, +) + + +class TextExtractionPredictionInstance(proto.Message): + r"""Prediction input format for Text Extraction. + + Attributes: + content (str): + The text snippet to make the predictions on. + mime_type (str): + The MIME type of the text snippet. The + supported MIME types are listed below. + - text/plain + key (str): + This field is only used for batch prediction. + If a key is provided, the batch prediction + result will be mapped to this key. If omitted, + then the batch prediction result will contain + the entire input instance. AI Platform will not + check if keys in the request are duplicates, so + it is up to the caller to ensure the keys are + unique.
+ """ + + content = proto.Field(proto.STRING, number=1) + + mime_type = proto.Field(proto.STRING, number=2) + + key = proto.Field(proto.STRING, number=3) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py new file mode 100644 index 0000000000..37353ad806 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/text_sentiment.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"TextSentimentPredictionInstance",}, +) + + +class TextSentimentPredictionInstance(proto.Message): + r"""Prediction input format for Text Sentiment. + + Attributes: + content (str): + The text snippet to make the predictions on. + mime_type (str): + The MIME type of the text snippet. The + supported MIME types are listed below. + - text/plain + """ + + content = proto.Field(proto.STRING, number=1) + + mime_type = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py new file mode 100644 index 0000000000..6de5665312 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_action_recognition.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"VideoActionRecognitionPredictionInstance",}, +) + + +class VideoActionRecognitionPredictionInstance(proto.Message): + r"""Prediction input format for Video Action Recognition. + + Attributes: + content (str): + The Google Cloud Storage location of the + video on which to perform the prediction. + mime_type (str): + The MIME type of the content of the video. + Only the following are supported: video/mp4 + video/avi video/quicktime + time_segment_start (str): + The beginning, inclusive, of the video's time + segment on which to perform the prediction. 
+ Expressed as a number of seconds as measured + from the start of the video, with "s" appended + at the end. Fractions are allowed, up to a + microsecond precision. + time_segment_end (str): + The end, exclusive, of the video's time + segment on which to perform the prediction. + Expressed as a number of seconds as measured + from the start of the video, with "s" appended + at the end. Fractions are allowed, up to a + microsecond precision, and "inf" or "Infinity" + is allowed, which means the end of the video. + """ + + content = proto.Field(proto.STRING, number=1) + + mime_type = proto.Field(proto.STRING, number=2) + + time_segment_start = proto.Field(proto.STRING, number=3) + + time_segment_end = proto.Field(proto.STRING, number=4) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py new file mode 100644 index 0000000000..ab7c0edfe1 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_classification.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"VideoClassificationPredictionInstance",}, +) + + +class VideoClassificationPredictionInstance(proto.Message): + r"""Prediction input format for Video Classification. + + Attributes: + content (str): + The Google Cloud Storage location of the + video on which to perform the prediction. + mime_type (str): + The MIME type of the content of the video. + Only the following are supported: video/mp4 + video/avi video/quicktime + time_segment_start (str): + The beginning, inclusive, of the video's time + segment on which to perform the prediction. + Expressed as a number of seconds as measured + from the start of the video, with "s" appended + at the end. Fractions are allowed, up to a + microsecond precision. + time_segment_end (str): + The end, exclusive, of the video's time + segment on which to perform the prediction. + Expressed as a number of seconds as measured + from the start of the video, with "s" appended + at the end. Fractions are allowed, up to a + microsecond precision, and "inf" or "Infinity" + is allowed, which means the end of the video. 
+ """ + + content = proto.Field(proto.STRING, number=1) + + mime_type = proto.Field(proto.STRING, number=2) + + time_segment_start = proto.Field(proto.STRING, number=3) + + time_segment_end = proto.Field(proto.STRING, number=4) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py new file mode 100644 index 0000000000..f797f58f4e --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/types/video_object_tracking.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.instance", + manifest={"VideoObjectTrackingPredictionInstance",}, +) + + +class VideoObjectTrackingPredictionInstance(proto.Message): + r"""Prediction input format for Video Object Tracking. + + Attributes: + content (str): + The Google Cloud Storage location of the + video on which to perform the prediction. + mime_type (str): + The MIME type of the content of the video. + Only the following are supported: video/mp4 + video/avi video/quicktime + time_segment_start (str): + The beginning, inclusive, of the video's time + segment on which to perform the prediction. + Expressed as a number of seconds as measured + from the start of the video, with "s" appended + at the end. Fractions are allowed, up to a + microsecond precision. + time_segment_end (str): + The end, exclusive, of the video's time + segment on which to perform the prediction. + Expressed as a number of seconds as measured + from the start of the video, with "s" appended + at the end. Fractions are allowed, up to a + microsecond precision, and "inf" or "Infinity" + is allowed, which means the end of the video. + """ + + content = proto.Field(proto.STRING, number=1) + + mime_type = proto.Field(proto.STRING, number=2) + + time_segment_start = proto.Field(proto.STRING, number=3) + + time_segment_end = proto.Field(proto.STRING, number=4) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/params/__init__.py b/google/cloud/aiplatform/v1/schema/predict/params/__init__.py new file mode 100644 index 0000000000..c046f4d7e5 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/params/__init__.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_classification import ( + ImageClassificationPredictionParams, +) +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_object_detection import ( + ImageObjectDetectionPredictionParams, +) +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.image_segmentation import ( + ImageSegmentationPredictionParams, +) +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_action_recognition import ( + VideoActionRecognitionPredictionParams, +) +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_classification import ( + VideoClassificationPredictionParams, +) +from google.cloud.aiplatform.v1.schema.predict.params_v1.types.video_object_tracking import ( + VideoObjectTrackingPredictionParams, +) + +__all__ = ( + "ImageClassificationPredictionParams", + "ImageObjectDetectionPredictionParams", + "ImageSegmentationPredictionParams", + "VideoActionRecognitionPredictionParams", + "VideoClassificationPredictionParams", + "VideoObjectTrackingPredictionParams", +) diff --git a/google/cloud/aiplatform/v1/schema/predict/params/py.typed b/google/cloud/aiplatform/v1/schema/predict/params/py.typed new file mode 100644 index 0000000000..df96e61590 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/params/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-cloud-aiplatform-v1-schema-predict-params package uses inline types. diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py new file mode 100644 index 0000000000..79fb1c2097 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/__init__.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .types.image_classification import ImageClassificationPredictionParams +from .types.image_object_detection import ImageObjectDetectionPredictionParams +from .types.image_segmentation import ImageSegmentationPredictionParams +from .types.video_action_recognition import VideoActionRecognitionPredictionParams +from .types.video_classification import VideoClassificationPredictionParams +from .types.video_object_tracking import VideoObjectTrackingPredictionParams + + +__all__ = ( + "ImageObjectDetectionPredictionParams", + "ImageSegmentationPredictionParams", + "VideoActionRecognitionPredictionParams", + "VideoClassificationPredictionParams", + "VideoObjectTrackingPredictionParams", + "ImageClassificationPredictionParams", +) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/py.typed b/google/cloud/aiplatform/v1/schema/predict/params_v1/py.typed new file mode 100644 index 0000000000..df96e61590 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. 
+# The google-cloud-aiplatform-v1-schema-predict-params package uses inline types. diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py new file mode 100644 index 0000000000..2f2c29bba5 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/__init__.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .image_classification import ImageClassificationPredictionParams +from .image_object_detection import ImageObjectDetectionPredictionParams +from .image_segmentation import ImageSegmentationPredictionParams +from .video_action_recognition import VideoActionRecognitionPredictionParams +from .video_classification import VideoClassificationPredictionParams +from .video_object_tracking import VideoObjectTrackingPredictionParams + +__all__ = ( + "ImageClassificationPredictionParams", + "ImageObjectDetectionPredictionParams", + "ImageSegmentationPredictionParams", + "VideoActionRecognitionPredictionParams", + "VideoClassificationPredictionParams", + "VideoObjectTrackingPredictionParams", +) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py new file mode 100644 index 0000000000..3a9efd0ea2 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_classification.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.params", + manifest={"ImageClassificationPredictionParams",}, +) + + +class ImageClassificationPredictionParams(proto.Message): + r"""Prediction model parameters for Image Classification. + + Attributes: + confidence_threshold (float): + The Model only returns predictions with at + least this confidence score. Default value is + 0.0 + max_predictions (int): + The Model only returns up to that many top, + by confidence score, predictions per instance. + If this number is very high, the Model may + return fewer predictions. Default value is 10. 
+ """ + + confidence_threshold = proto.Field(proto.FLOAT, number=1) + + max_predictions = proto.Field(proto.INT32, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py new file mode 100644 index 0000000000..c37507a4e0 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_object_detection.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.params", + manifest={"ImageObjectDetectionPredictionParams",}, +) + + +class ImageObjectDetectionPredictionParams(proto.Message): + r"""Prediction model parameters for Image Object Detection. + + Attributes: + confidence_threshold (float): + The Model only returns predictions with at + least this confidence score. Default value is + 0.0 + max_predictions (int): + The Model only returns up to that many top, + by confidence score, predictions per instance. + Note that number of returned predictions is also + limited by metadata's predictionsLimit. Default + value is 10. + """ + + confidence_threshold = proto.Field(proto.FLOAT, number=1) + + max_predictions = proto.Field(proto.INT32, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py new file mode 100644 index 0000000000..108cff107b --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/image_segmentation.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.params", + manifest={"ImageSegmentationPredictionParams",}, +) + + +class ImageSegmentationPredictionParams(proto.Message): + r"""Prediction model parameters for Image Segmentation. + + Attributes: + confidence_threshold (float): + When the model predicts category of pixels of + the image, it will only provide predictions for + pixels that it is at least this much confident + about. All other pixels will be classified as + background. Default value is 0.5. 
+ """ + + confidence_threshold = proto.Field(proto.FLOAT, number=1) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py new file mode 100644 index 0000000000..66f1f19e76 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_action_recognition.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.params", + manifest={"VideoActionRecognitionPredictionParams",}, +) + + +class VideoActionRecognitionPredictionParams(proto.Message): + r"""Prediction model parameters for Video Action Recognition. + + Attributes: + confidence_threshold (float): + The Model only returns predictions with at + least this confidence score. Default value is + 0.0 + max_predictions (int): + The model only returns up to that many top, + by confidence score, predictions per frame of + the video. If this number is very high, the + Model may return fewer predictions per frame. + Default value is 50. + """ + + confidence_threshold = proto.Field(proto.FLOAT, number=1) + + max_predictions = proto.Field(proto.INT32, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py new file mode 100644 index 0000000000..bfe8df9f5c --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_classification.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.params", + manifest={"VideoClassificationPredictionParams",}, +) + + +class VideoClassificationPredictionParams(proto.Message): + r"""Prediction model parameters for Video Classification. + + Attributes: + confidence_threshold (float): + The Model only returns predictions with at + least this confidence score. Default value is + 0.0 + max_predictions (int): + The Model only returns up to that many top, + by confidence score, predictions per instance. 
+ If this number is very high, the Model may + return fewer predictions. Default value is + 10,000. + segment_classification (bool): + Set to true to request segment-level + classification. AI Platform returns labels and + their confidence scores for the entire time + segment of the video that the user specified in the + input instance. Default value is true + shot_classification (bool): + Set to true to request shot-level + classification. AI Platform determines the + boundaries for each camera shot in the entire + time segment of the video that the user specified in + the input instance. AI Platform then returns + labels and their confidence scores for each + detected shot, along with the start and end time + of the shot. + WARNING: Model evaluation is not done for this + classification type, the quality of it depends + on the training data, but there are no metrics + provided to describe that quality. + Default value is false + one_sec_interval_classification (bool): + Set to true to request classification for a + video at one-second intervals. AI Platform + returns labels and their confidence scores for + each second of the entire time segment of the + video that the user specified in the input instance. + WARNING: Model evaluation is not done for this + classification type, the quality of it depends + on the training data, but there are no metrics + provided to describe that quality. Default value + is false + """ + + confidence_threshold = proto.Field(proto.FLOAT, number=1) + + max_predictions = proto.Field(proto.INT32, number=2) + + segment_classification = proto.Field(proto.BOOL, number=3) + + shot_classification = proto.Field(proto.BOOL, number=4) + + one_sec_interval_classification = proto.Field(proto.BOOL, number=5) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py new file mode 100644 index 0000000000..899de1050a --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/types/video_object_tracking.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.params", + manifest={"VideoObjectTrackingPredictionParams",}, +) + + +class VideoObjectTrackingPredictionParams(proto.Message): + r"""Prediction model parameters for Video Object Tracking. + + Attributes: + confidence_threshold (float): + The Model only returns predictions with at + least this confidence score. Default value is + 0.0 + max_predictions (int): + The model only returns up to that many top, + by confidence score, predictions per frame of + the video. If this number is very high, the + Model may return fewer predictions per frame. + Default value is 50.
+ min_bounding_box_size (float): + Only bounding boxes with shortest edge at + least that long as a relative value of video + frame size are returned. Default value is 0.0. + """ + + confidence_threshold = proto.Field(proto.FLOAT, number=1) + + max_predictions = proto.Field(proto.INT32, number=2) + + min_bounding_box_size = proto.Field(proto.FLOAT, number=3) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py b/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py new file mode 100644 index 0000000000..d8e2b782c2 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/__init__.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.classification import ( + ClassificationPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.image_object_detection import ( + ImageObjectDetectionPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.image_segmentation import ( + ImageSegmentationPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.tabular_classification import ( + TabularClassificationPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.tabular_regression import ( + TabularRegressionPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.text_extraction import ( + TextExtractionPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.text_sentiment import ( + TextSentimentPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_action_recognition import ( + VideoActionRecognitionPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_classification import ( + VideoClassificationPredictionResult, +) +from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.video_object_tracking import ( + VideoObjectTrackingPredictionResult, +) + +__all__ = ( + "ClassificationPredictionResult", + "ImageObjectDetectionPredictionResult", + "ImageSegmentationPredictionResult", + "TabularClassificationPredictionResult", + "TabularRegressionPredictionResult", + "TextExtractionPredictionResult", + "TextSentimentPredictionResult", + "VideoActionRecognitionPredictionResult", + "VideoClassificationPredictionResult", + "VideoObjectTrackingPredictionResult", +) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/py.typed b/google/cloud/aiplatform/v1/schema/predict/prediction/py.typed new file mode 100644 index 0000000000..472fa4d8cc --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-cloud-aiplatform-v1-schema-predict-prediction package uses inline types. 
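A brief usage note on the schema packages introduced above: the typed instance and params classes replace hand-assembled dicts when composing a predict request. The following is a minimal sketch, not part of the patch itself; the file name and the parameter values are hypothetical, and the remaining request wiring (converting each message to a google.protobuf.struct_pb2.Value and passing it to PredictionServiceClient.predict) is omitted.

import base64

from google.cloud.aiplatform.v1.schema.predict import instance, params

# Base64-encode the image bytes ("image.jpg" is a hypothetical local file).
with open("image.jpg", "rb") as f:
    encoded_content = base64.b64encode(f.read()).decode("utf-8")

# Typed request payload; the fields mirror the
# ImageClassificationPredictionInstance docstring above
# (content plus one of the supported MIME types).
image_instance = instance.ImageClassificationPredictionInstance(
    content=encoded_content,
    mime_type="image/jpeg",
)

# Typed model parameters; 0.5 and 5 are illustrative choices, not defaults.
image_params = params.ImageClassificationPredictionParams(
    confidence_threshold=0.5,
    max_predictions=5,
)

Because both classes are proto-plus messages, a misspelled field name or a wrongly typed value fails at construction time rather than at serving time, which is the practical benefit over raw dicts.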
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py new file mode 100644 index 0000000000..91fae5a3b1 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/__init__.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .types.classification import ClassificationPredictionResult +from .types.image_object_detection import ImageObjectDetectionPredictionResult +from .types.image_segmentation import ImageSegmentationPredictionResult +from .types.tabular_classification import TabularClassificationPredictionResult +from .types.tabular_regression import TabularRegressionPredictionResult +from .types.text_extraction import TextExtractionPredictionResult +from .types.text_sentiment import TextSentimentPredictionResult +from .types.video_action_recognition import VideoActionRecognitionPredictionResult +from .types.video_classification import VideoClassificationPredictionResult +from .types.video_object_tracking import VideoObjectTrackingPredictionResult + + +__all__ = ( + "ImageObjectDetectionPredictionResult", + "ImageSegmentationPredictionResult", + "TabularClassificationPredictionResult", + "TabularRegressionPredictionResult", + "TextExtractionPredictionResult", + "TextSentimentPredictionResult", + "VideoActionRecognitionPredictionResult", + "VideoClassificationPredictionResult", + "VideoObjectTrackingPredictionResult", + "ClassificationPredictionResult", +) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/py.typed b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/py.typed new file mode 100644 index 0000000000..472fa4d8cc --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-cloud-aiplatform-v1-schema-predict-prediction package uses inline types. diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py new file mode 100644 index 0000000000..a0fd2058e0 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/__init__.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from .classification import ClassificationPredictionResult +from .image_object_detection import ImageObjectDetectionPredictionResult +from .image_segmentation import ImageSegmentationPredictionResult +from .tabular_classification import TabularClassificationPredictionResult +from .tabular_regression import TabularRegressionPredictionResult +from .text_extraction import TextExtractionPredictionResult +from .text_sentiment import TextSentimentPredictionResult +from .video_action_recognition import VideoActionRecognitionPredictionResult +from .video_classification import VideoClassificationPredictionResult +from .video_object_tracking import VideoObjectTrackingPredictionResult + +__all__ = ( + "ClassificationPredictionResult", + "ImageObjectDetectionPredictionResult", + "ImageSegmentationPredictionResult", + "TabularClassificationPredictionResult", + "TabularRegressionPredictionResult", + "TextExtractionPredictionResult", + "TextSentimentPredictionResult", + "VideoActionRecognitionPredictionResult", + "VideoClassificationPredictionResult", + "VideoObjectTrackingPredictionResult", +) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py new file mode 100644 index 0000000000..cfc8e2e602 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/classification.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"ClassificationPredictionResult",}, +) + + +class ClassificationPredictionResult(proto.Message): + r"""Prediction output format for Image and Text Classification. + + Attributes: + ids (Sequence[int]): + The resource IDs of the AnnotationSpecs that + had been identified, ordered by the confidence + score descendingly. + display_names (Sequence[str]): + The display names of the AnnotationSpecs that + had been identified, order matches the IDs. + confidences (Sequence[float]): + The Model's confidences in correctness of the + predicted IDs, higher value means higher + confidence. Order matches the Ids. 
+ """ + + ids = proto.RepeatedField(proto.INT64, number=1) + + display_names = proto.RepeatedField(proto.STRING, number=2) + + confidences = proto.RepeatedField(proto.FLOAT, number=3) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py new file mode 100644 index 0000000000..31d37010db --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_object_detection.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import struct_pb2 as struct # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"ImageObjectDetectionPredictionResult",}, +) + + +class ImageObjectDetectionPredictionResult(proto.Message): + r"""Prediction output format for Image Object Detection. + + Attributes: + ids (Sequence[int]): + The resource IDs of the AnnotationSpecs that + had been identified, ordered by the confidence + score descendingly. + display_names (Sequence[str]): + The display names of the AnnotationSpecs that + had been identified, order matches the IDs. + confidences (Sequence[float]): + The Model's confidences in correctness of the + predicted IDs, higher value means higher + confidence. Order matches the Ids. + bboxes (Sequence[google.protobuf.struct_pb2.ListValue]): + Bounding boxes, i.e. the rectangles over the image, that + pinpoint the found AnnotationSpecs. Given in order that + matches the IDs. Each bounding box is an array of 4 numbers + ``xMin``, ``xMax``, ``yMin``, and ``yMax``, which represent + the extremal coordinates of the box. They are relative to + the image size, and the point 0,0 is in the top left of the + image. + """ + + ids = proto.RepeatedField(proto.INT64, number=1) + + display_names = proto.RepeatedField(proto.STRING, number=2) + + confidences = proto.RepeatedField(proto.FLOAT, number=3) + + bboxes = proto.RepeatedField(proto.MESSAGE, number=4, message=struct.ListValue,) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py new file mode 100644 index 0000000000..1261f19723 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/image_segmentation.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"ImageSegmentationPredictionResult",}, +) + + +class ImageSegmentationPredictionResult(proto.Message): + r"""Prediction output format for Image Segmentation. + + Attributes: + category_mask (str): + A PNG image where each pixel in the mask + represents the category to which the pixel in + the original image was predicted to belong. + The size of this image will be the same as the + original image. The mapping between the + AnnotationSpec and the color can be found in + the model's metadata. The model will choose the most + likely category, and if none of the categories + reach the confidence threshold, the pixel will + be marked as background. + confidence_mask (str): + A one-channel image which is encoded as an + 8-bit lossless PNG. The size of the image will be + the same as the original image. For a specific + pixel, darker color means less confidence in + the correctness of the category in the categoryMask + for the corresponding pixel. Black means no + confidence and white means complete confidence. + """ + + category_mask = proto.Field(proto.STRING, number=1) + + confidence_mask = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py new file mode 100644 index 0000000000..7e78051467 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_classification.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"TabularClassificationPredictionResult",}, +) + + +class TabularClassificationPredictionResult(proto.Message): + r"""Prediction output format for Tabular Classification. + + Attributes: + classes (Sequence[str]): + The names of the classes being classified; + contains all possible values of the target + column. + scores (Sequence[float]): + The model's confidence in each class being + correct, higher value means higher confidence. + The N-th score corresponds to the N-th class in + classes.
+ """ + + classes = proto.RepeatedField(proto.STRING, number=1) + + scores = proto.RepeatedField(proto.FLOAT, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py new file mode 100644 index 0000000000..c813f3e45c --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/tabular_regression.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"TabularRegressionPredictionResult",}, +) + + +class TabularRegressionPredictionResult(proto.Message): + r"""Prediction output format for Tabular Regression. + + Attributes: + value (float): + The regression value. + lower_bound (float): + The lower bound of the prediction interval. + upper_bound (float): + The upper bound of the prediction interval. + """ + + value = proto.Field(proto.FLOAT, number=1) + + lower_bound = proto.Field(proto.FLOAT, number=2) + + upper_bound = proto.Field(proto.FLOAT, number=3) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py new file mode 100644 index 0000000000..201f10d08a --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_extraction.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.predict.prediction", + manifest={"TextExtractionPredictionResult",}, +) + + +class TextExtractionPredictionResult(proto.Message): + r"""Prediction output format for Text Extraction. + + Attributes: + ids (Sequence[int]): + The resource IDs of the AnnotationSpecs that + had been identified, ordered by the confidence + score descendingly. + display_names (Sequence[str]): + The display names of the AnnotationSpecs that + had been identified, order matches the IDs. + text_segment_start_offsets (Sequence[int]): + The start offsets, inclusive, of the text + segment in which the AnnotationSpec has been + identified. 
Expressed as a zero-based number of
+            characters as measured from the start of the
+            text snippet.
+        text_segment_end_offsets (Sequence[int]):
+            The end offsets, inclusive, of the text
+            segment in which the AnnotationSpec has been
+            identified. Expressed as a zero-based number of
+            characters as measured from the start of the
+            text snippet.
+        confidences (Sequence[float]):
+            The Model's confidence in the correctness of
+            each predicted ID; a higher value means higher
+            confidence. The order matches the IDs.
+    """
+
+    ids = proto.RepeatedField(proto.INT64, number=1)
+
+    display_names = proto.RepeatedField(proto.STRING, number=2)
+
+    text_segment_start_offsets = proto.RepeatedField(proto.INT64, number=3)
+
+    text_segment_end_offsets = proto.RepeatedField(proto.INT64, number=4)
+
+    confidences = proto.RepeatedField(proto.FLOAT, number=5)
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py
new file mode 100644
index 0000000000..73c670f4ec
--- /dev/null
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/text_sentiment.py
@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import proto  # type: ignore
+
+
+__protobuf__ = proto.module(
+    package="google.cloud.aiplatform.v1.schema.predict.prediction",
+    manifest={"TextSentimentPredictionResult",},
+)
+
+
+class TextSentimentPredictionResult(proto.Message):
+    r"""Prediction output format for Text Sentiment.
+
+    Attributes:
+        sentiment (int):
+            The integer sentiment label, between 0
+            (inclusive) and sentimentMax (inclusive), where
+            0 maps to the least positive sentiment and
+            sentimentMax maps to the most positive one. The
+            higher the score, the more positive the
+            sentiment in the text snippet. Note:
+            sentimentMax is an integer value between 1
+            (inclusive) and 10 (inclusive).
+    """
+
+    sentiment = proto.Field(proto.INT32, number=1)
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py
new file mode 100644
index 0000000000..486853c63d
--- /dev/null
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_action_recognition.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
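+# --- Editor's illustrative sketch (hedged, not part of the generated file):
+# the TextExtractionPredictionResult defined above is a set of parallel
+# arrays, so callers can zip them back into per-entity records; the sample
+# values below are invented. ---
+#
+# from google.cloud.aiplatform.v1.schema.predict.prediction_v1.types import (
+#     text_extraction,
+# )
+#
+# result = text_extraction.TextExtractionPredictionResult(
+#     ids=[3, 7],
+#     display_names=["person", "location"],
+#     text_segment_start_offsets=[0, 10],
+#     text_segment_end_offsets=[4, 15],
+#     confidences=[0.98, 0.87],
+# )
+# for id_, name, start, end, confidence in zip(
+#     result.ids,
+#     result.display_names,
+#     result.text_segment_start_offsets,
+#     result.text_segment_end_offsets,
+#     result.confidences,
+# ):
+#     print(f"{name} ({id_}): chars [{start}, {end}], confidence {confidence:.2f}")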
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import proto  # type: ignore
+
+
+from google.protobuf import duration_pb2 as duration  # type: ignore
+from google.protobuf import wrappers_pb2 as wrappers  # type: ignore
+
+
+__protobuf__ = proto.module(
+    package="google.cloud.aiplatform.v1.schema.predict.prediction",
+    manifest={"VideoActionRecognitionPredictionResult",},
+)
+
+
+class VideoActionRecognitionPredictionResult(proto.Message):
+    r"""Prediction output format for Video Action Recognition.
+
+    Attributes:
+        id (str):
+            The resource ID of the AnnotationSpec that
+            had been identified.
+        display_name (str):
+            The display name of the AnnotationSpec that
+            had been identified.
+        time_segment_start (google.protobuf.duration_pb2.Duration):
+            The beginning, inclusive, of the video's time
+            segment in which the AnnotationSpec has been
+            identified. Expressed as a number of seconds as
+            measured from the start of the video, with
+            fractions up to a microsecond precision, and
+            with "s" appended at the end.
+        time_segment_end (google.protobuf.duration_pb2.Duration):
+            The end, exclusive, of the video's time
+            segment in which the AnnotationSpec has been
+            identified. Expressed as a number of seconds as
+            measured from the start of the video, with
+            fractions up to a microsecond precision, and
+            with "s" appended at the end.
+        confidence (google.protobuf.wrappers_pb2.FloatValue):
+            The Model's confidence in the correctness of
+            this prediction; a higher value means higher
+            confidence.
+    """
+
+    id = proto.Field(proto.STRING, number=1)
+
+    display_name = proto.Field(proto.STRING, number=2)
+
+    time_segment_start = proto.Field(
+        proto.MESSAGE, number=4, message=duration.Duration,
+    )
+
+    time_segment_end = proto.Field(proto.MESSAGE, number=5, message=duration.Duration,)
+
+    confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers.FloatValue,)
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py
new file mode 100644
index 0000000000..c043547d04
--- /dev/null
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_classification.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import proto  # type: ignore
+
+
+from google.protobuf import duration_pb2 as duration  # type: ignore
+from google.protobuf import wrappers_pb2 as wrappers  # type: ignore
+
+
+__protobuf__ = proto.module(
+    package="google.cloud.aiplatform.v1.schema.predict.prediction",
+    manifest={"VideoClassificationPredictionResult",},
+)
+
+
+class VideoClassificationPredictionResult(proto.Message):
+    r"""Prediction output format for Video Classification.
+
+    Attributes:
+        id (str):
+            The resource ID of the AnnotationSpec that
+            had been identified.
+        display_name (str):
+            The display name of the AnnotationSpec that
+            had been identified.
+        type_ (str):
+            The type of the prediction. The requested
+            types can be configured via parameters. This
+            will be one of:
+            - segment-classification
+            - shot-classification
+            - one-sec-interval-classification
+        time_segment_start (google.protobuf.duration_pb2.Duration):
+            The beginning, inclusive, of the video's time
+            segment in which the AnnotationSpec has been
+            identified. Expressed as a number of seconds as
+            measured from the start of the video, with
+            fractions up to a microsecond precision, and
+            with "s" appended at the end. Note that for the
+            'segment-classification' prediction type, this
+            equals the original 'timeSegmentStart' from the
+            input instance; for other types it is the start
+            of a shot or a 1-second interval, respectively.
+        time_segment_end (google.protobuf.duration_pb2.Duration):
+            The end, exclusive, of the video's time
+            segment in which the AnnotationSpec has been
+            identified. Expressed as a number of seconds as
+            measured from the start of the video, with
+            fractions up to a microsecond precision, and
+            with "s" appended at the end. Note that for the
+            'segment-classification' prediction type, this
+            equals the original 'timeSegmentEnd' from the
+            input instance; for other types it is the end of
+            a shot or a 1-second interval, respectively.
+        confidence (google.protobuf.wrappers_pb2.FloatValue):
+            The Model's confidence in the correctness of
+            this prediction; a higher value means higher
+            confidence.
+    """
+
+    id = proto.Field(proto.STRING, number=1)
+
+    display_name = proto.Field(proto.STRING, number=2)
+
+    type_ = proto.Field(proto.STRING, number=3)
+
+    time_segment_start = proto.Field(
+        proto.MESSAGE, number=4, message=duration.Duration,
+    )
+
+    time_segment_end = proto.Field(proto.MESSAGE, number=5, message=duration.Duration,)
+
+    confidence = proto.Field(proto.MESSAGE, number=6, message=wrappers.FloatValue,)
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py
new file mode 100644
index 0000000000..d1b515a895
--- /dev/null
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/types/video_object_tracking.py
@@ -0,0 +1,115 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import proto  # type: ignore
+
+
+from google.protobuf import duration_pb2 as duration  # type: ignore
+from google.protobuf import wrappers_pb2 as wrappers  # type: ignore
+
+
+__protobuf__ = proto.module(
+    package="google.cloud.aiplatform.v1.schema.predict.prediction",
+    manifest={"VideoObjectTrackingPredictionResult",},
+)
+
+
+class VideoObjectTrackingPredictionResult(proto.Message):
+    r"""Prediction output format for Video Object Tracking.
+
+    Attributes:
+        id (str):
+            The resource ID of the AnnotationSpec that
+            had been identified.
+        display_name (str):
+            The display name of the AnnotationSpec that
+            had been identified.
+        time_segment_start (google.protobuf.duration_pb2.Duration):
+            The beginning, inclusive, of the video's time
+            segment in which the object instance has been
+            detected. Expressed as a number of seconds as
+            measured from the start of the video, with
+            fractions up to a microsecond precision, and
+            with "s" appended at the end.
+        time_segment_end (google.protobuf.duration_pb2.Duration):
+            The end, inclusive, of the video's time
+            segment in which the object instance has been
+            detected. Expressed as a number of seconds as
+            measured from the start of the video, with
+            fractions up to a microsecond precision, and
+            with "s" appended at the end.
+        confidence (google.protobuf.wrappers_pb2.FloatValue):
+            The Model's confidence in the correctness of
+            this prediction; a higher value means higher
+            confidence.
+        frames (Sequence[google.cloud.aiplatform.v1.schema.predict.prediction_v1.types.VideoObjectTrackingPredictionResult.Frame]):
+            All of the frames of the video in which a
+            single object instance has been detected. The
+            bounding boxes in the frames identify the same
+            object.
+    """
+
+    class Frame(proto.Message):
+        r"""The fields ``xMin``, ``xMax``, ``yMin``, and ``yMax`` refer to a
+        bounding box, i.e. the rectangle over the video frame pinpointing
+        the found AnnotationSpec. The coordinates are relative to the frame
+        size, and the point 0,0 is in the top left of the frame.
+
+        Attributes:
+            time_offset (google.protobuf.duration_pb2.Duration):
+                A time (frame) of a video in which the object
+                has been detected. Expressed as a number of
+                seconds as measured from the start of the video,
+                with fractions up to a microsecond precision,
+                and with "s" appended at the end.
+            x_min (google.protobuf.wrappers_pb2.FloatValue):
+                The leftmost coordinate of the bounding box.
+            x_max (google.protobuf.wrappers_pb2.FloatValue):
+                The rightmost coordinate of the bounding box.
+            y_min (google.protobuf.wrappers_pb2.FloatValue):
+                The topmost coordinate of the bounding box.
+            y_max (google.protobuf.wrappers_pb2.FloatValue):
+                The bottommost coordinate of the bounding
+                box.
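+
+        Example (editor's illustrative sketch, not part of the generated
+        docstring; coordinate values are invented). The ``FloatValue``
+        wrappers let an absent coordinate be distinguished from 0.0:
+
+            from google.protobuf import wrappers_pb2
+
+            frame = VideoObjectTrackingPredictionResult.Frame(
+                x_min=wrappers_pb2.FloatValue(value=0.1),
+                x_max=wrappers_pb2.FloatValue(value=0.4),
+                y_min=wrappers_pb2.FloatValue(value=0.2),
+                y_max=wrappers_pb2.FloatValue(value=0.6),
+            )
+            # Relative width of the box within the frame.
+            box_width = frame.x_max.value - frame.x_min.value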
+ """ + + time_offset = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) + + x_min = proto.Field(proto.MESSAGE, number=2, message=wrappers.FloatValue,) + + x_max = proto.Field(proto.MESSAGE, number=3, message=wrappers.FloatValue,) + + y_min = proto.Field(proto.MESSAGE, number=4, message=wrappers.FloatValue,) + + y_max = proto.Field(proto.MESSAGE, number=5, message=wrappers.FloatValue,) + + id = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + time_segment_start = proto.Field( + proto.MESSAGE, number=3, message=duration.Duration, + ) + + time_segment_end = proto.Field(proto.MESSAGE, number=4, message=duration.Duration,) + + confidence = proto.Field(proto.MESSAGE, number=5, message=wrappers.FloatValue,) + + frames = proto.RepeatedField(proto.MESSAGE, number=6, message=Frame,) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py new file mode 100644 index 0000000000..f8620bb25d --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/__init__.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import ( + AutoMlImageClassification, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import ( + AutoMlImageClassificationInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_classification import ( + AutoMlImageClassificationMetadata, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import ( + AutoMlImageObjectDetection, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import ( + AutoMlImageObjectDetectionInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_object_detection import ( + AutoMlImageObjectDetectionMetadata, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import ( + AutoMlImageSegmentation, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import ( + AutoMlImageSegmentationInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_image_segmentation import ( + AutoMlImageSegmentationMetadata, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import ( + AutoMlTables, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import ( + AutoMlTablesInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_tables import ( + AutoMlTablesMetadata, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_classification import ( + AutoMlTextClassification, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_classification import ( + AutoMlTextClassificationInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_extraction import ( + AutoMlTextExtraction, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_extraction import ( + AutoMlTextExtractionInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_sentiment import ( + AutoMlTextSentiment, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_text_sentiment import ( + AutoMlTextSentimentInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_action_recognition import ( + AutoMlVideoActionRecognition, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_action_recognition import ( + AutoMlVideoActionRecognitionInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_classification import ( + AutoMlVideoClassification, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_classification import ( + AutoMlVideoClassificationInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_object_tracking import ( + AutoMlVideoObjectTracking, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.automl_video_object_tracking import ( + AutoMlVideoObjectTrackingInputs, +) +from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.export_evaluated_data_items_config import ( + ExportEvaluatedDataItemsConfig, +) + +__all__ = ( + 
"AutoMlImageClassification", + "AutoMlImageClassificationInputs", + "AutoMlImageClassificationMetadata", + "AutoMlImageObjectDetection", + "AutoMlImageObjectDetectionInputs", + "AutoMlImageObjectDetectionMetadata", + "AutoMlImageSegmentation", + "AutoMlImageSegmentationInputs", + "AutoMlImageSegmentationMetadata", + "AutoMlTables", + "AutoMlTablesInputs", + "AutoMlTablesMetadata", + "AutoMlTextClassification", + "AutoMlTextClassificationInputs", + "AutoMlTextExtraction", + "AutoMlTextExtractionInputs", + "AutoMlTextSentiment", + "AutoMlTextSentimentInputs", + "AutoMlVideoActionRecognition", + "AutoMlVideoActionRecognitionInputs", + "AutoMlVideoClassification", + "AutoMlVideoClassificationInputs", + "AutoMlVideoObjectTracking", + "AutoMlVideoObjectTrackingInputs", + "ExportEvaluatedDataItemsConfig", +) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/py.typed b/google/cloud/aiplatform/v1/schema/trainingjob/definition/py.typed new file mode 100644 index 0000000000..1a9d2972a0 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-cloud-aiplatform-v1-schema-trainingjob-definition package uses inline types. diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py new file mode 100644 index 0000000000..34958e5add --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/__init__.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from .types.automl_image_classification import AutoMlImageClassification +from .types.automl_image_classification import AutoMlImageClassificationInputs +from .types.automl_image_classification import AutoMlImageClassificationMetadata +from .types.automl_image_object_detection import AutoMlImageObjectDetection +from .types.automl_image_object_detection import AutoMlImageObjectDetectionInputs +from .types.automl_image_object_detection import AutoMlImageObjectDetectionMetadata +from .types.automl_image_segmentation import AutoMlImageSegmentation +from .types.automl_image_segmentation import AutoMlImageSegmentationInputs +from .types.automl_image_segmentation import AutoMlImageSegmentationMetadata +from .types.automl_tables import AutoMlTables +from .types.automl_tables import AutoMlTablesInputs +from .types.automl_tables import AutoMlTablesMetadata +from .types.automl_text_classification import AutoMlTextClassification +from .types.automl_text_classification import AutoMlTextClassificationInputs +from .types.automl_text_extraction import AutoMlTextExtraction +from .types.automl_text_extraction import AutoMlTextExtractionInputs +from .types.automl_text_sentiment import AutoMlTextSentiment +from .types.automl_text_sentiment import AutoMlTextSentimentInputs +from .types.automl_video_action_recognition import AutoMlVideoActionRecognition +from .types.automl_video_action_recognition import AutoMlVideoActionRecognitionInputs +from .types.automl_video_classification import AutoMlVideoClassification +from .types.automl_video_classification import AutoMlVideoClassificationInputs +from .types.automl_video_object_tracking import AutoMlVideoObjectTracking +from .types.automl_video_object_tracking import AutoMlVideoObjectTrackingInputs +from .types.export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig + + +__all__ = ( + "AutoMlImageClassificationInputs", + "AutoMlImageClassificationMetadata", + "AutoMlImageObjectDetection", + "AutoMlImageObjectDetectionInputs", + "AutoMlImageObjectDetectionMetadata", + "AutoMlImageSegmentation", + "AutoMlImageSegmentationInputs", + "AutoMlImageSegmentationMetadata", + "AutoMlTables", + "AutoMlTablesInputs", + "AutoMlTablesMetadata", + "AutoMlTextClassification", + "AutoMlTextClassificationInputs", + "AutoMlTextExtraction", + "AutoMlTextExtractionInputs", + "AutoMlTextSentiment", + "AutoMlTextSentimentInputs", + "AutoMlVideoActionRecognition", + "AutoMlVideoActionRecognitionInputs", + "AutoMlVideoClassification", + "AutoMlVideoClassificationInputs", + "AutoMlVideoObjectTracking", + "AutoMlVideoObjectTrackingInputs", + "ExportEvaluatedDataItemsConfig", + "AutoMlImageClassification", +) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/py.typed b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/py.typed new file mode 100644 index 0000000000..1a9d2972a0 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-cloud-aiplatform-v1-schema-trainingjob-definition package uses inline types. 
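+# --- Editor's note (illustrative, not part of the patch): this ``py.typed``
+# marker opts the package into PEP 561, so a type checker such as mypy reads
+# the inline annotations straight from the installed package when checking
+# caller code like the hedged sketch below:
+#
+# from google.cloud.aiplatform.v1.schema.trainingjob import definition
+#
+# inputs: definition.AutoMlTablesInputs = definition.AutoMlTablesInputs(
+#     prediction_type="classification",
+# )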
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py new file mode 100644 index 0000000000..8a60b2e36c --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/__init__.py @@ -0,0 +1,90 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .automl_image_classification import ( + AutoMlImageClassification, + AutoMlImageClassificationInputs, + AutoMlImageClassificationMetadata, +) +from .automl_image_object_detection import ( + AutoMlImageObjectDetection, + AutoMlImageObjectDetectionInputs, + AutoMlImageObjectDetectionMetadata, +) +from .automl_image_segmentation import ( + AutoMlImageSegmentation, + AutoMlImageSegmentationInputs, + AutoMlImageSegmentationMetadata, +) +from .export_evaluated_data_items_config import ExportEvaluatedDataItemsConfig +from .automl_tables import ( + AutoMlTables, + AutoMlTablesInputs, + AutoMlTablesMetadata, +) +from .automl_text_classification import ( + AutoMlTextClassification, + AutoMlTextClassificationInputs, +) +from .automl_text_extraction import ( + AutoMlTextExtraction, + AutoMlTextExtractionInputs, +) +from .automl_text_sentiment import ( + AutoMlTextSentiment, + AutoMlTextSentimentInputs, +) +from .automl_video_action_recognition import ( + AutoMlVideoActionRecognition, + AutoMlVideoActionRecognitionInputs, +) +from .automl_video_classification import ( + AutoMlVideoClassification, + AutoMlVideoClassificationInputs, +) +from .automl_video_object_tracking import ( + AutoMlVideoObjectTracking, + AutoMlVideoObjectTrackingInputs, +) + +__all__ = ( + "AutoMlImageClassification", + "AutoMlImageClassificationInputs", + "AutoMlImageClassificationMetadata", + "AutoMlImageObjectDetection", + "AutoMlImageObjectDetectionInputs", + "AutoMlImageObjectDetectionMetadata", + "AutoMlImageSegmentation", + "AutoMlImageSegmentationInputs", + "AutoMlImageSegmentationMetadata", + "ExportEvaluatedDataItemsConfig", + "AutoMlTables", + "AutoMlTablesInputs", + "AutoMlTablesMetadata", + "AutoMlTextClassification", + "AutoMlTextClassificationInputs", + "AutoMlTextExtraction", + "AutoMlTextExtractionInputs", + "AutoMlTextSentiment", + "AutoMlTextSentimentInputs", + "AutoMlVideoActionRecognition", + "AutoMlVideoActionRecognitionInputs", + "AutoMlVideoClassification", + "AutoMlVideoClassificationInputs", + "AutoMlVideoObjectTracking", + "AutoMlVideoObjectTrackingInputs", +) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py new file mode 100644 index 0000000000..f7e13c60b7 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_classification.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={ + "AutoMlImageClassification", + "AutoMlImageClassificationInputs", + "AutoMlImageClassificationMetadata", + }, +) + + +class AutoMlImageClassification(proto.Message): + r"""A TrainingJob that trains and uploads an AutoML Image + Classification Model. + + Attributes: + inputs (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageClassificationInputs): + The input parameters of this TrainingJob. + metadata (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageClassificationMetadata): + The metadata information. + """ + + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlImageClassificationInputs", + ) + + metadata = proto.Field( + proto.MESSAGE, number=2, message="AutoMlImageClassificationMetadata", + ) + + +class AutoMlImageClassificationInputs(proto.Message): + r""" + + Attributes: + model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageClassificationInputs.ModelType): + + base_model_id (str): + The ID of the ``base`` model. If it is specified, the new + model will be trained based on the ``base`` model. + Otherwise, the new model will be trained from scratch. The + ``base`` model must be in the same Project and Location as + the new Model to train, and have the same modelType. + budget_milli_node_hours (int): + The training budget of creating this model, expressed in + milli node hours i.e. 1,000 value in this field means 1 node + hour. The actual metadata.costMilliNodeHours will be equal + or less than this value. If further model training ceases to + provide any improvements, it will stop without using the + full budget and the metadata.successfulStopReason will be + ``model-converged``. Note, node_hour = actual_hour \* + number_of_nodes_involved. For modelType + ``cloud``\ (default), the budget must be between 8,000 and + 800,000 milli node hours, inclusive. The default value is + 192,000 which represents one day in wall time, considering 8 + nodes are used. For model types ``mobile-tf-low-latency-1``, + ``mobile-tf-versatile-1``, ``mobile-tf-high-accuracy-1``, + the training budget must be between 1,000 and 100,000 milli + node hours, inclusive. The default value is 24,000 which + represents one day in wall time on a single node that is + used. + disable_early_stopping (bool): + Use the entire training budget. This disables + the early stopping feature. When false the early + stopping feature is enabled, which means that + AutoML Image Classification might stop training + before the entire training budget has been used. + multi_label (bool): + If false, a single-label (multi-class) Model + will be trained (i.e. assuming that for each + image just up to one annotation may be + applicable). If true, a multi-label Model will + be trained (i.e. assuming that for each image + multiple annotations may be applicable). 
+ """ + + class ModelType(proto.Enum): + r"""""" + MODEL_TYPE_UNSPECIFIED = 0 + CLOUD = 1 + MOBILE_TF_LOW_LATENCY_1 = 2 + MOBILE_TF_VERSATILE_1 = 3 + MOBILE_TF_HIGH_ACCURACY_1 = 4 + + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + + base_model_id = proto.Field(proto.STRING, number=2) + + budget_milli_node_hours = proto.Field(proto.INT64, number=3) + + disable_early_stopping = proto.Field(proto.BOOL, number=4) + + multi_label = proto.Field(proto.BOOL, number=5) + + +class AutoMlImageClassificationMetadata(proto.Message): + r""" + + Attributes: + cost_milli_node_hours (int): + The actual training cost of creating this + model, expressed in milli node hours, i.e. 1,000 + value in this field means 1 node hour. + Guaranteed to not exceed + inputs.budgetMilliNodeHours. + successful_stop_reason (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageClassificationMetadata.SuccessfulStopReason): + For successful job completions, this is the + reason why the job has finished. + """ + + class SuccessfulStopReason(proto.Enum): + r"""""" + SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 + BUDGET_REACHED = 1 + MODEL_CONVERGED = 2 + + cost_milli_node_hours = proto.Field(proto.INT64, number=1) + + successful_stop_reason = proto.Field( + proto.ENUM, number=2, enum=SuccessfulStopReason, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py new file mode 100644 index 0000000000..1c2c9f83b7 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_object_detection.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={ + "AutoMlImageObjectDetection", + "AutoMlImageObjectDetectionInputs", + "AutoMlImageObjectDetectionMetadata", + }, +) + + +class AutoMlImageObjectDetection(proto.Message): + r"""A TrainingJob that trains and uploads an AutoML Image Object + Detection Model. + + Attributes: + inputs (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageObjectDetectionInputs): + The input parameters of this TrainingJob. 
+        metadata (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageObjectDetectionMetadata):
+            The metadata information.
+    """
+
+    inputs = proto.Field(
+        proto.MESSAGE, number=1, message="AutoMlImageObjectDetectionInputs",
+    )
+
+    metadata = proto.Field(
+        proto.MESSAGE, number=2, message="AutoMlImageObjectDetectionMetadata",
+    )
+
+
+class AutoMlImageObjectDetectionInputs(proto.Message):
+    r"""
+
+    Attributes:
+        model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageObjectDetectionInputs.ModelType):
+
+        budget_milli_node_hours (int):
+            The training budget for creating this model, expressed in
+            milli node hours, i.e., a value of 1,000 in this field
+            means 1 node hour. The actual metadata.costMilliNodeHours
+            will be equal to or less than this value. If further model
+            training ceases to provide any improvements, it will stop
+            without using the full budget and the
+            metadata.successfulStopReason will be ``model-converged``.
+            Note, node_hour = actual_hour \* number_of_nodes_involved.
+            For modelType ``cloud``\ (default), the budget must be
+            between 20,000 and 900,000 milli node hours, inclusive. The
+            default value is 216,000, which represents one day in wall
+            time, assuming 9 nodes are used. For model types
+            ``mobile-tf-low-latency-1``, ``mobile-tf-versatile-1``,
+            ``mobile-tf-high-accuracy-1``, the training budget must be
+            between 1,000 and 100,000 milli node hours, inclusive. The
+            default value is 24,000, which represents one day in wall
+            time on the single node that is used.
+        disable_early_stopping (bool):
+            Use the entire training budget. This disables
+            the early stopping feature. When false, the
+            early stopping feature is enabled, which means
+            that AutoML Image Object Detection might stop
+            training before the entire training budget has
+            been used.
+    """
+
+    class ModelType(proto.Enum):
+        r""""""
+        MODEL_TYPE_UNSPECIFIED = 0
+        CLOUD_HIGH_ACCURACY_1 = 1
+        CLOUD_LOW_LATENCY_1 = 2
+        MOBILE_TF_LOW_LATENCY_1 = 3
+        MOBILE_TF_VERSATILE_1 = 4
+        MOBILE_TF_HIGH_ACCURACY_1 = 5
+
+    model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
+
+    budget_milli_node_hours = proto.Field(proto.INT64, number=2)
+
+    disable_early_stopping = proto.Field(proto.BOOL, number=3)
+
+
+class AutoMlImageObjectDetectionMetadata(proto.Message):
+    r"""
+
+    Attributes:
+        cost_milli_node_hours (int):
+            The actual training cost of creating this
+            model, expressed in milli node hours, i.e., a
+            value of 1,000 in this field means 1 node hour.
+            Guaranteed to not exceed
+            inputs.budgetMilliNodeHours.
+        successful_stop_reason (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageObjectDetectionMetadata.SuccessfulStopReason):
+            For successful job completions, this is the
+            reason why the job has finished.
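+
+    Example (editor's illustrative sketch, not part of the generated
+    docstring; ``metadata`` stands for a finished job's parsed metadata
+    message and is a hypothetical variable):
+
+        Reason = AutoMlImageObjectDetectionMetadata.SuccessfulStopReason
+        if metadata.successful_stop_reason == Reason.MODEL_CONVERGED:
+            # Training stopped early, so cost_milli_node_hours may be
+            # well under the requested budget.
+            pass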
+ """ + + class SuccessfulStopReason(proto.Enum): + r"""""" + SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0 + BUDGET_REACHED = 1 + MODEL_CONVERGED = 2 + + cost_milli_node_hours = proto.Field(proto.INT64, number=1) + + successful_stop_reason = proto.Field( + proto.ENUM, number=2, enum=SuccessfulStopReason, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py new file mode 100644 index 0000000000..a81103657e --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_image_segmentation.py @@ -0,0 +1,122 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={ + "AutoMlImageSegmentation", + "AutoMlImageSegmentationInputs", + "AutoMlImageSegmentationMetadata", + }, +) + + +class AutoMlImageSegmentation(proto.Message): + r"""A TrainingJob that trains and uploads an AutoML Image + Segmentation Model. + + Attributes: + inputs (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageSegmentationInputs): + The input parameters of this TrainingJob. + metadata (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageSegmentationMetadata): + The metadata information. + """ + + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlImageSegmentationInputs", + ) + + metadata = proto.Field( + proto.MESSAGE, number=2, message="AutoMlImageSegmentationMetadata", + ) + + +class AutoMlImageSegmentationInputs(proto.Message): + r""" + + Attributes: + model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageSegmentationInputs.ModelType): + + budget_milli_node_hours (int): + The training budget of creating this model, expressed in + milli node hours i.e. 1,000 value in this field means 1 node + hour. The actual metadata.costMilliNodeHours will be equal + or less than this value. If further model training ceases to + provide any improvements, it will stop without using the + full budget and the metadata.successfulStopReason will be + ``model-converged``. Note, node_hour = actual_hour \* + number_of_nodes_involved. Or actaul_wall_clock_hours = + train_budget_milli_node_hours / (number_of_nodes_involved \* + 1000) For modelType ``cloud-high-accuracy-1``\ (default), + the budget must be between 20,000 and 2,000,000 milli node + hours, inclusive. The default value is 192,000 which + represents one day in wall time (1000 milli \* 24 hours \* 8 + nodes). + base_model_id (str): + The ID of the ``base`` model. If it is specified, the new + model will be trained based on the ``base`` model. + Otherwise, the new model will be trained from scratch. 
The
+            ``base`` model must be in the same Project and Location as
+            the new Model to train, and have the same modelType.
+    """
+
+    class ModelType(proto.Enum):
+        r""""""
+        MODEL_TYPE_UNSPECIFIED = 0
+        CLOUD_HIGH_ACCURACY_1 = 1
+        CLOUD_LOW_ACCURACY_1 = 2
+        MOBILE_TF_LOW_LATENCY_1 = 3
+
+    model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,)
+
+    budget_milli_node_hours = proto.Field(proto.INT64, number=2)
+
+    base_model_id = proto.Field(proto.STRING, number=3)
+
+
+class AutoMlImageSegmentationMetadata(proto.Message):
+    r"""
+
+    Attributes:
+        cost_milli_node_hours (int):
+            The actual training cost of creating this
+            model, expressed in milli node hours, i.e., a
+            value of 1,000 in this field means 1 node hour.
+            Guaranteed to not exceed
+            inputs.budgetMilliNodeHours.
+        successful_stop_reason (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlImageSegmentationMetadata.SuccessfulStopReason):
+            For successful job completions, this is the
+            reason why the job has finished.
+    """
+
+    class SuccessfulStopReason(proto.Enum):
+        r""""""
+        SUCCESSFUL_STOP_REASON_UNSPECIFIED = 0
+        BUDGET_REACHED = 1
+        MODEL_CONVERGED = 2
+
+    cost_milli_node_hours = proto.Field(proto.INT64, number=1)
+
+    successful_stop_reason = proto.Field(
+        proto.ENUM, number=2, enum=SuccessfulStopReason,
+    )
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py
new file mode 100644
index 0000000000..1c3d0c8da7
--- /dev/null
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_tables.py
@@ -0,0 +1,447 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import proto  # type: ignore
+
+
+from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types import (
+    export_evaluated_data_items_config as gcastd_export_evaluated_data_items_config,
+)
+
+
+__protobuf__ = proto.module(
+    package="google.cloud.aiplatform.v1.schema.trainingjob.definition",
+    manifest={"AutoMlTables", "AutoMlTablesInputs", "AutoMlTablesMetadata",},
+)
+
+
+class AutoMlTables(proto.Message):
+    r"""A TrainingJob that trains and uploads an AutoML Tables Model.
+
+    Attributes:
+        inputs (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs):
+            The input parameters of this TrainingJob.
+        metadata (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesMetadata):
+            The metadata information.
+    """
+
+    inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTablesInputs",)
+
+    metadata = proto.Field(proto.MESSAGE, number=2, message="AutoMlTablesMetadata",)
+
+
+class AutoMlTablesInputs(proto.Message):
+    r"""
+
+    Attributes:
+        optimization_objective_recall_value (float):
+            Required when optimization_objective is
+            "maximize-precision-at-recall". Must be between 0 and 1,
+            inclusive.
+        optimization_objective_precision_value (float):
+            Required when optimization_objective is
+            "maximize-recall-at-precision". Must be between 0 and 1,
+            inclusive.
+        prediction_type (str):
+            The type of prediction the Model is to produce.
+            "classification" - Predict one out of multiple
+            target values that is picked for each row.
+            "regression" - Predict a value based on its
+            relation to other values. This type is available
+            only to columns that contain semantically
+            numeric values, i.e. integers or floating point
+            numbers, even if stored as e.g. strings.
+        target_column (str):
+            The column name of the target column that the
+            model is to predict.
+        transformations (Sequence[google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation]):
+            Each transformation will apply a transform
+            function to the given input column, and the
+            result will be used for training. When creating
+            a transformation for a BigQuery Struct column,
+            the column should be flattened using "." as the
+            delimiter.
+        optimization_objective (str):
+            Objective function the model is optimizing
+            towards. The training process creates a model
+            that maximizes/minimizes the value of the
+            objective function over the validation set.
+
+            The supported optimization objectives depend on
+            the prediction type. If the field is not set, a
+            default objective function is used.
+
+            classification (binary):
+            "maximize-au-roc" (default) - Maximize the area
+            under the receiver operating characteristic
+            (ROC) curve.
+            "minimize-log-loss" - Minimize log loss.
+            "maximize-au-prc" - Maximize the area under the
+            precision-recall curve.
+            "maximize-precision-at-recall" - Maximize
+            precision for a specified recall value.
+            "maximize-recall-at-precision" - Maximize recall
+            for a specified precision value.
+
+            classification (multi-class):
+            "minimize-log-loss" (default) - Minimize log
+            loss.
+
+            regression:
+            "minimize-rmse" (default) - Minimize
+            root-mean-squared error (RMSE).
+            "minimize-mae" - Minimize mean-absolute error
+            (MAE).
+            "minimize-rmsle" - Minimize root-mean-squared
+            log error (RMSLE).
+        train_budget_milli_node_hours (int):
+            Required. The train budget for creating this
+            model, expressed in milli node hours, i.e., a
+            value of 1,000 in this field means 1 node hour.
+            The training cost of the model will not exceed
+            this budget. The final cost will attempt to be
+            close to the budget, though it may end up being
+            (even) noticeably smaller, at the backend's
+            discretion. This especially may happen when
+            further model training ceases to provide any
+            improvements.
+            If the budget is set to a value known to be
+            insufficient to train a model for the given
+            dataset, the training won't be attempted and
+            will error.
+
+            The train budget must be between 1,000 and
+            72,000 milli node hours, inclusive.
+        disable_early_stopping (bool):
+            Use the entire training budget. This disables
+            the early stopping feature. By default, the
+            early stopping feature is enabled, which means
+            that AutoML Tables might stop training before
+            the entire training budget has been used.
+        weight_column_name (str):
+            Column name that should be used as the weight
+            column. Higher values in this column give more
+            importance to the row during model training. The
+            column must have numeric values between 0 and
+            10000, inclusive; 0 means the row is ignored
+            for training. If the weight column field is not
+            set, then all rows are assumed to have an equal
+            weight of 1.
+        export_evaluated_data_items_config (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.ExportEvaluatedDataItemsConfig):
+            Configuration for exporting test set
+            predictions to a BigQuery table. If this
+            configuration is absent, then the export is not
+            performed.
+    """
+
+    class Transformation(proto.Message):
+        r"""
+
+        Attributes:
+            auto (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.AutoTransformation):
+
+            numeric (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.NumericTransformation):
+
+            categorical (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.CategoricalTransformation):
+
+            timestamp (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.TimestampTransformation):
+
+            text (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.TextTransformation):
+
+            repeated_numeric (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.NumericArrayTransformation):
+
+            repeated_categorical (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.CategoricalArrayTransformation):
+
+            repeated_text (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTablesInputs.Transformation.TextArrayTransformation):
+
+        """
+
+        class AutoTransformation(proto.Message):
+            r"""The training pipeline will infer the proper transformation
+            based on the statistics of the dataset.
+
+            Attributes:
+                column_name (str):
+
+            """
+
+            column_name = proto.Field(proto.STRING, number=1)
+
+        class NumericTransformation(proto.Message):
+            r"""The training pipeline will perform the following transformation
+            functions.
+
+            - The value converted to float32.
+            - The z_score of the value.
+            - log(value+1) when the value is greater than or equal to 0.
+              Otherwise, this transformation is not applied and the value is
+              considered a missing value.
+            - z_score of log(value+1) when the value is greater than or
+              equal to 0. Otherwise, this transformation is not applied and
+              the value is considered a missing value.
+            - A boolean value that indicates whether the value is valid.
+
+            Attributes:
+                column_name (str):
+
+                invalid_values_allowed (bool):
+                    If invalid values are allowed, the training
+                    pipeline will create a boolean feature that
+                    indicates whether the value is valid. Otherwise,
+                    the training pipeline will discard the input row
+                    from the training data.
+            """
+
+            column_name = proto.Field(proto.STRING, number=1)
+
+            invalid_values_allowed = proto.Field(proto.BOOL, number=2)
+
+        class CategoricalTransformation(proto.Message):
+            r"""The training pipeline will perform the following transformation
+            functions.
+
+            - The categorical string as is--no change to case, punctuation,
+              spelling, tense, and so on.
+            - Convert the category name to a dictionary lookup index and
+              generate an embedding for each index.
+            - Categories that appear fewer than 5 times in the training
+              dataset are treated as the "unknown" category. The "unknown"
+              category gets its own special lookup index and resulting
+              embedding.
+
+            Attributes:
+                column_name (str):
+
+            """
+
+            column_name = proto.Field(proto.STRING, number=1)
+
+        class TimestampTransformation(proto.Message):
+            r"""The training pipeline will perform the following transformation
+            functions.
+
+            - Apply the transformation functions for Numerical columns.
+            - Determine the year, month, day, and weekday. Treat each value
+              from the timestamp as a Categorical column.
+            - Invalid numerical values (for example, values that fall
+              outside of a typical timestamp range, or are extreme values)
+              receive no special treatment and are not removed.
+
+            Attributes:
+                column_name (str):
+
+                time_format (str):
+                    The format in which that time field is expressed. The
+                    time_format must either be one of:
+
+                    - ``unix-seconds``
+                    - ``unix-milliseconds``
+                    - ``unix-microseconds``
+                    - ``unix-nanoseconds``
+
+                    (for respectively number of seconds, milliseconds,
+                    microseconds and nanoseconds since start of the Unix
+                    epoch); or be written in ``strftime`` syntax. If
+                    time_format is not set, then the default format is RFC
+                    3339 ``date-time`` format, where ``time-offset`` =
+                    ``"Z"`` (e.g. 1985-04-12T23:20:50.52Z).
+                invalid_values_allowed (bool):
+                    If invalid values are allowed, the training
+                    pipeline will create a boolean feature that
+                    indicates whether the value is valid. Otherwise,
+                    the training pipeline will discard the input row
+                    from the training data.
+            """
+
+            column_name = proto.Field(proto.STRING, number=1)
+
+            time_format = proto.Field(proto.STRING, number=2)
+
+            invalid_values_allowed = proto.Field(proto.BOOL, number=3)
+
+        class TextTransformation(proto.Message):
+            r"""The training pipeline will perform the following transformation
+            functions.
+
+            - The text as is--no change to case, punctuation, spelling,
+              tense, and so on.
+            - Tokenize text to words. Convert each word to a dictionary
+              lookup index and generate an embedding for each index. Combine
+              the embedding of all elements into a single embedding using
+              the mean.
+            - Tokenization is based on unicode script boundaries.
+            - Missing values get their own lookup index and resulting
+              embedding.
+            - Stop-words receive no special treatment and are not removed.
+
+            Attributes:
+                column_name (str):
+
+            """
+
+            column_name = proto.Field(proto.STRING, number=1)
+
+        class NumericArrayTransformation(proto.Message):
+            r"""Treats the column as a numerical array and performs the following
+            transformation functions.
+
+            - All transformations for Numerical types applied to the average
+              of all elements.
+            - The average of empty arrays is treated as zero.
+
+            Attributes:
+                column_name (str):
+
+                invalid_values_allowed (bool):
+                    If invalid values are allowed, the training
+                    pipeline will create a boolean feature that
+                    indicates whether the value is valid. Otherwise,
+                    the training pipeline will discard the input row
+                    from the training data.
+            """
+
+            column_name = proto.Field(proto.STRING, number=1)
+
+            invalid_values_allowed = proto.Field(proto.BOOL, number=2)
+
+        class CategoricalArrayTransformation(proto.Message):
+            r"""Treats the column as a categorical array and performs the
+            following transformation functions.
+
+            - For each element in the array, convert the category name to a
+              dictionary lookup index and generate an embedding for each
+              index. Combine the embedding of all elements into a single
+              embedding using the mean.
+            - Empty arrays are treated as an embedding of zeroes.
+
+            Attributes:
+                column_name (str):
+
+            """
+
+            column_name = proto.Field(proto.STRING, number=1)
+
+        class TextArrayTransformation(proto.Message):
+            r"""Treats the column as a text array and performs the following
+            transformation functions.
+
+            - Concatenate all text values in the array into a single text
+              value using a space (" ") as a delimiter, and then treat the
+              result as a single text value. Apply the transformations for
+              Text columns.
+            - Empty arrays are treated as empty text.
+
+            Attributes:
+                column_name (str):
+
+            """
+
+            column_name = proto.Field(proto.STRING, number=1)
+
+        auto = proto.Field(
+            proto.MESSAGE,
+            number=1,
+            oneof="transformation_detail",
+            message="AutoMlTablesInputs.Transformation.AutoTransformation",
+        )
+
+        numeric = proto.Field(
+            proto.MESSAGE,
+            number=2,
+            oneof="transformation_detail",
+            message="AutoMlTablesInputs.Transformation.NumericTransformation",
+        )
+
+        categorical = proto.Field(
+            proto.MESSAGE,
+            number=3,
+            oneof="transformation_detail",
+            message="AutoMlTablesInputs.Transformation.CategoricalTransformation",
+        )
+
+        timestamp = proto.Field(
+            proto.MESSAGE,
+            number=4,
+            oneof="transformation_detail",
+            message="AutoMlTablesInputs.Transformation.TimestampTransformation",
+        )
+
+        text = proto.Field(
+            proto.MESSAGE,
+            number=5,
+            oneof="transformation_detail",
+            message="AutoMlTablesInputs.Transformation.TextTransformation",
+        )
+
+        repeated_numeric = proto.Field(
+            proto.MESSAGE,
+            number=6,
+            oneof="transformation_detail",
+            message="AutoMlTablesInputs.Transformation.NumericArrayTransformation",
+        )
+
+        repeated_categorical = proto.Field(
+            proto.MESSAGE,
+            number=7,
+            oneof="transformation_detail",
+            message="AutoMlTablesInputs.Transformation.CategoricalArrayTransformation",
+        )
+
+        repeated_text = proto.Field(
+            proto.MESSAGE,
+            number=8,
+            oneof="transformation_detail",
+            message="AutoMlTablesInputs.Transformation.TextArrayTransformation",
+        )
+
+    optimization_objective_recall_value = proto.Field(
+        proto.FLOAT, number=5, oneof="additional_optimization_objective_config"
+    )
+
+    optimization_objective_precision_value = proto.Field(
+        proto.FLOAT, number=6, oneof="additional_optimization_objective_config"
+    )
+
+    prediction_type = proto.Field(proto.STRING, number=1)
+
+    target_column = proto.Field(proto.STRING, number=2)
+
+    transformations = proto.RepeatedField(
+        proto.MESSAGE, number=3, message=Transformation,
+    )
+
+    optimization_objective = proto.Field(proto.STRING, number=4)
+
+    train_budget_milli_node_hours = proto.Field(proto.INT64, number=7)
+
+    disable_early_stopping = proto.Field(proto.BOOL, number=8)
+
+    weight_column_name = proto.Field(proto.STRING, number=9)
+
+    export_evaluated_data_items_config = proto.Field(
+        proto.MESSAGE,
+        number=10,
+        message=gcastd_export_evaluated_data_items_config.ExportEvaluatedDataItemsConfig,
+    )
+
+
+class AutoMlTablesMetadata(proto.Message):
+    r"""Model metadata specific to AutoML Tables.
+
+    Attributes:
+        train_cost_milli_node_hours (int):
+            Output only. The actual training cost of the
+            model, expressed in milli node hours, i.e., a
+            value of 1,000 in this field means 1 node hour.
+            Guaranteed to not exceed the train budget.
+    """
+
+    train_cost_milli_node_hours = proto.Field(proto.INT64, number=1)
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py
new file mode 100644
index 0000000000..205deaf375
--- /dev/null
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_classification.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
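+# --- Editor's illustrative sketch (hedged, not part of the generated file):
+# one way the AutoMlTablesInputs message defined above could be assembled,
+# routing a single column through the ``numeric`` branch of the
+# ``transformation_detail`` oneof; all column names are invented. ---
+#
+# from google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types import (
+#     automl_tables,
+# )
+#
+# Transformation = automl_tables.AutoMlTablesInputs.Transformation
+# inputs = automl_tables.AutoMlTablesInputs(
+#     prediction_type="regression",
+#     target_column="sale_price",
+#     optimization_objective="minimize-rmse",
+#     train_budget_milli_node_hours=1_000,
+#     transformations=[
+#         Transformation(
+#             numeric=Transformation.NumericTransformation(
+#                 column_name="square_feet",
+#             ),
+#         ),
+#     ],
+# )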
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlTextClassification", "AutoMlTextClassificationInputs",}, +) + + +class AutoMlTextClassification(proto.Message): + r"""A TrainingJob that trains and uploads an AutoML Text + Classification Model. + + Attributes: + inputs (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTextClassificationInputs): + The input parameters of this TrainingJob. + """ + + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlTextClassificationInputs", + ) + + +class AutoMlTextClassificationInputs(proto.Message): + r""" + + Attributes: + multi_label (bool): + + """ + + multi_label = proto.Field(proto.BOOL, number=1) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py new file mode 100644 index 0000000000..fad28847af --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_extraction.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlTextExtraction", "AutoMlTextExtractionInputs",}, +) + + +class AutoMlTextExtraction(proto.Message): + r"""A TrainingJob that trains and uploads an AutoML Text + Extraction Model. + + Attributes: + inputs (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTextExtractionInputs): + The input parameters of this TrainingJob. + """ + + inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextExtractionInputs",) + + +class AutoMlTextExtractionInputs(proto.Message): + r"""""" + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py new file mode 100644 index 0000000000..ca80a44d1d --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_text_sentiment.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlTextSentiment", "AutoMlTextSentimentInputs",}, +) + + +class AutoMlTextSentiment(proto.Message): + r"""A TrainingJob that trains and uploads an AutoML Text + Sentiment Model. + + Attributes: + inputs (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlTextSentimentInputs): + The input parameters of this TrainingJob. + """ + + inputs = proto.Field(proto.MESSAGE, number=1, message="AutoMlTextSentimentInputs",) + + +class AutoMlTextSentimentInputs(proto.Message): + r""" + + Attributes: + sentiment_max (int): + A sentiment is expressed as an integer + ordinal, where higher value means a more + positive sentiment. The range of sentiments that + will be used is between 0 and sentimentMax + (inclusive on both ends), and all the values in + the range must be represented in the dataset + before a model can be created. + Only the Annotations with this sentimentMax will + be used for training. sentimentMax value must be + between 1 and 10 (inclusive). + """ + + sentiment_max = proto.Field(proto.INT32, number=1) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py new file mode 100644 index 0000000000..1a20a6d725 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_action_recognition.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlVideoActionRecognition", "AutoMlVideoActionRecognitionInputs",}, +) + + +class AutoMlVideoActionRecognition(proto.Message): + r"""A TrainingJob that trains and uploads an AutoML Video Action + Recognition Model. + + Attributes: + inputs (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoActionRecognitionInputs): + The input parameters of this TrainingJob. 
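Since sentiment_max fully determines the label range for the text sentiment
inputs above, configuring it is the whole job of that message. A minimal
sketch, again assuming the definition_v1 re-exports; the 5-point scale is
hypothetical:

    from google.cloud.aiplatform.v1.schema.trainingjob import definition_v1

    # Labels 0..4 are valid; every label in that range must appear in the
    # training data, and sentiment_max itself must be between 1 and 10.
    sentiment_inputs = definition_v1.AutoMlTextSentimentInputs(sentiment_max=4)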
+ """ + + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlVideoActionRecognitionInputs", + ) + + +class AutoMlVideoActionRecognitionInputs(proto.Message): + r""" + + Attributes: + model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoActionRecognitionInputs.ModelType): + + """ + + class ModelType(proto.Enum): + r"""""" + MODEL_TYPE_UNSPECIFIED = 0 + CLOUD = 1 + MOBILE_VERSATILE_1 = 2 + + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py new file mode 100644 index 0000000000..ba7f2d5b21 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_classification.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlVideoClassification", "AutoMlVideoClassificationInputs",}, +) + + +class AutoMlVideoClassification(proto.Message): + r"""A TrainingJob that trains and uploads an AutoML Video + Classification Model. + + Attributes: + inputs (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoClassificationInputs): + The input parameters of this TrainingJob. + """ + + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlVideoClassificationInputs", + ) + + +class AutoMlVideoClassificationInputs(proto.Message): + r""" + + Attributes: + model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoClassificationInputs.ModelType): + + """ + + class ModelType(proto.Enum): + r"""""" + MODEL_TYPE_UNSPECIFIED = 0 + CLOUD = 1 + MOBILE_VERSATILE_1 = 2 + MOBILE_JETSON_VERSATILE_1 = 3 + + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py new file mode 100644 index 0000000000..0ecb1113d9 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/automl_video_object_tracking.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"AutoMlVideoObjectTracking", "AutoMlVideoObjectTrackingInputs",}, +) + + +class AutoMlVideoObjectTracking(proto.Message): + r"""A TrainingJob that trains and uploads an AutoML Video + ObjectTracking Model. + + Attributes: + inputs (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoObjectTrackingInputs): + The input parameters of this TrainingJob. + """ + + inputs = proto.Field( + proto.MESSAGE, number=1, message="AutoMlVideoObjectTrackingInputs", + ) + + +class AutoMlVideoObjectTrackingInputs(proto.Message): + r""" + + Attributes: + model_type (google.cloud.aiplatform.v1.schema.trainingjob.definition_v1.types.AutoMlVideoObjectTrackingInputs.ModelType): + + """ + + class ModelType(proto.Enum): + r"""""" + MODEL_TYPE_UNSPECIFIED = 0 + CLOUD = 1 + MOBILE_VERSATILE_1 = 2 + MOBILE_CORAL_VERSATILE_1 = 3 + MOBILE_CORAL_LOW_LATENCY_1 = 4 + MOBILE_JETSON_VERSATILE_1 = 5 + MOBILE_JETSON_LOW_LATENCY_1 = 6 + + model_type = proto.Field(proto.ENUM, number=1, enum=ModelType,) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py new file mode 100644 index 0000000000..dc8a629412 --- /dev/null +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/types/export_evaluated_data_items_config.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1.schema.trainingjob.definition", + manifest={"ExportEvaluatedDataItemsConfig",}, +) + + +class ExportEvaluatedDataItemsConfig(proto.Message): + r"""Configuration for exporting test set predictions to a + BigQuery table. + + Attributes: + destination_bigquery_uri (str): + URI of desired destination BigQuery table. Expected format: + bq://:: + + If not specified, then results are exported to the following + auto-created BigQuery table: + + :export_evaluated_examples__.evaluated_examples + override_existing_table (bool): + If true and an export destination is + specified, then the contents of the destination + are overwritten. Otherwise, if the export + destination already exists, then the export + operation fails. 
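The two fields of ExportEvaluatedDataItemsConfig work together: the URI
selects the destination table, and the flag decides whether an existing table
is overwritten or the export fails. A hedged sketch follows; the bq:// URI is
a made-up value in the colon-separated form the docstring describes (its
placeholders did not survive rendering), and the import again assumes the
definition_v1 re-exports.

    from google.cloud.aiplatform.v1.schema.trainingjob import definition_v1

    # Overwrite any previous contents of the destination table; with
    # override_existing_table=False the export fails if the table exists.
    export_config = definition_v1.ExportEvaluatedDataItemsConfig(
        destination_bigquery_uri="bq://my-project:my_dataset:evaluated_examples",
        override_existing_table=True,
    )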
+ """ + + destination_bigquery_uri = proto.Field(proto.STRING, number=1) + + override_existing_table = proto.Field(proto.BOOL, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py index 3160c08e1d..041fe6cdb1 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/types/__init__.py @@ -25,7 +25,6 @@ from .video_classification import VideoClassificationPredictionInstance from .video_object_tracking import VideoObjectTrackingPredictionInstance - __all__ = ( "ImageClassificationPredictionInstance", "ImageObjectDetectionPredictionInstance", diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py index 39202720fa..2f2c29bba5 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/types/__init__.py @@ -22,7 +22,6 @@ from .video_classification import VideoClassificationPredictionParams from .video_object_tracking import VideoObjectTrackingPredictionParams - __all__ = ( "ImageClassificationPredictionParams", "ImageObjectDetectionPredictionParams", diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py index 2d6c8a98d3..5ec1ed095e 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/__init__.py @@ -27,7 +27,6 @@ from .video_classification import VideoClassificationPredictionResult from .video_object_tracking import VideoObjectTrackingPredictionResult - __all__ = ( "ClassificationPredictionResult", "ImageObjectDetectionPredictionResult", diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py index 1bf5002c2a..3d0f7f1f76 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/image_object_detection.py @@ -42,7 +42,7 @@ class ImageObjectDetectionPredictionResult(proto.Message): The Model's confidences in correctness of the predicted IDs, higher value means higher confidence. Order matches the Ids. - bboxes (Sequence[~.struct.ListValue]): + bboxes (Sequence[google.protobuf.struct_pb2.ListValue]): Bounding boxes, i.e. the rectangles over the image, that pinpoint the found AnnotationSpecs. Given in order that matches the IDs. 
Each bounding box is an array of 4 numbers diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py index 39ef21bf21..f31b95a18f 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py @@ -17,11 +17,6 @@ import proto # type: ignore -# DO NOT OVERWRITE FOLLOWING LINE: it was manually edited. -from google.cloud.aiplatform.v1beta1.schema.predict.instance import ( - TextSentimentPredictionInstance, -) - __protobuf__ = proto.module( package="google.cloud.aiplatform.v1beta1.schema.predict.prediction", @@ -30,39 +25,21 @@ class TextSentimentPredictionResult(proto.Message): - r"""Represents a line of JSONL in the text sentiment batch - prediction output file. This is a hack to allow printing of - integer values. + r"""Prediction output format for Text Sentiment Attributes: - instance (~.gcaspi_text_sentiment.TextSentimentPredictionInstance): - User's input instance. - prediction (~.gcaspp_text_sentiment.TextSentimentPredictionResult.Prediction): - The prediction result. + sentiment (int): + The integer sentiment labels between 0 + (inclusive) and sentimentMax label (inclusive), + while 0 maps to the least positive sentiment and + sentimentMax maps to the most positive one. The + higher the score is, the more positive the + sentiment in the text snippet is. Note: + sentimentMax is an integer value between 1 + (inclusive) and 10 (inclusive). """ - class Prediction(proto.Message): - r"""Prediction output format for Text Sentiment. - - Attributes: - sentiment (int): - The integer sentiment labels between 0 - (inclusive) and sentimentMax label (inclusive), - while 0 maps to the least positive sentiment and - sentimentMax maps to the most positive one. The - higher the score is, the more positive the - sentiment in the text snippet is. Note: - sentimentMax is an integer value between 1 - (inclusive) and 10 (inclusive). - """ - - sentiment = proto.Field(proto.INT32, number=1) - - instance = proto.Field( - proto.MESSAGE, number=1, message=TextSentimentPredictionInstance, - ) - - prediction = proto.Field(proto.MESSAGE, number=2, message=Prediction,) + sentiment = proto.Field(proto.INT32, number=1) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py index f76b51899b..99fa365b47 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_action_recognition.py @@ -38,21 +38,21 @@ class VideoActionRecognitionPredictionResult(proto.Message): display_name (str): The display name of the AnnotationSpec that had been identified. - time_segment_start (~.duration.Duration): + time_segment_start (google.protobuf.duration_pb2.Duration): The beginning, inclusive, of the video's time segment in which the AnnotationSpec has been identified. Expressed as a number of seconds as measured from the start of the video, with fractions up to a microsecond precision, and with "s" appended at the end. 
- time_segment_end (~.duration.Duration): + time_segment_end (google.protobuf.duration_pb2.Duration): The end, exclusive, of the video's time segment in which the AnnotationSpec has been identified. Expressed as a number of seconds as measured from the start of the video, with fractions up to a microsecond precision, and with "s" appended at the end. - confidence (~.wrappers.FloatValue): + confidence (google.protobuf.wrappers_pb2.FloatValue): The Model's confidence in correction of this prediction, higher value means higher confidence. diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py index 469023b122..3fca68fe64 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_classification.py @@ -44,7 +44,7 @@ class VideoClassificationPredictionResult(proto.Message): will be one of - segment-classification - shot-classification - one-sec-interval-classification - time_segment_start (~.duration.Duration): + time_segment_start (google.protobuf.duration_pb2.Duration): The beginning, inclusive, of the video's time segment in which the AnnotationSpec has been identified. Expressed as a number of seconds as @@ -55,7 +55,7 @@ class VideoClassificationPredictionResult(proto.Message): equals the original 'timeSegmentStart' from the input instance, for other types it is the start of a shot or a 1 second interval respectively. - time_segment_end (~.duration.Duration): + time_segment_end (google.protobuf.duration_pb2.Duration): The end, exclusive, of the video's time segment in which the AnnotationSpec has been identified. Expressed as a number of seconds as @@ -66,7 +66,7 @@ class VideoClassificationPredictionResult(proto.Message): equals the original 'timeSegmentEnd' from the input instance, for other types it is the end of a shot or a 1 second interval respectively. - confidence (~.wrappers.FloatValue): + confidence (google.protobuf.wrappers_pb2.FloatValue): The Model's confidence in correction of this prediction, higher value means higher confidence. diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py index 026f80a325..6fd431c0dd 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/video_object_tracking.py @@ -38,25 +38,25 @@ class VideoObjectTrackingPredictionResult(proto.Message): display_name (str): The display name of the AnnotationSpec that had been identified. - time_segment_start (~.duration.Duration): + time_segment_start (google.protobuf.duration_pb2.Duration): The beginning, inclusive, of the video's time segment in which the object instance has been detected. Expressed as a number of seconds as measured from the start of the video, with fractions up to a microsecond precision, and with "s" appended at the end. - time_segment_end (~.duration.Duration): + time_segment_end (google.protobuf.duration_pb2.Duration): The end, inclusive, of the video's time segment in which the object instance has been detected. 
Expressed as a number of seconds as measured from the start of the video, with fractions up to a microsecond precision, and with "s" appended at the end. - confidence (~.wrappers.FloatValue): + confidence (google.protobuf.wrappers_pb2.FloatValue): The Model's confidence in correction of this prediction, higher value means higher confidence. - frames (Sequence[~.video_object_tracking.VideoObjectTrackingPredictionResult.Frame]): + frames (Sequence[google.cloud.aiplatform.v1beta1.schema.predict.prediction_v1beta1.types.VideoObjectTrackingPredictionResult.Frame]): All of the frames of the video in which a single object instance has been detected. The bounding boxes in the frames identify the same @@ -70,19 +70,19 @@ class Frame(proto.Message): size, and the point 0,0 is in the top left of the frame. Attributes: - time_offset (~.duration.Duration): + time_offset (google.protobuf.duration_pb2.Duration): A time (frame) of a video in which the object has been detected. Expressed as a number of seconds as measured from the start of the video, with fractions up to a microsecond precision, and with "s" appended at the end. - x_min (~.wrappers.FloatValue): + x_min (google.protobuf.wrappers_pb2.FloatValue): The leftmost coordinate of the bounding box. - x_max (~.wrappers.FloatValue): + x_max (google.protobuf.wrappers_pb2.FloatValue): The rightmost coordinate of the bounding box. - y_min (~.wrappers.FloatValue): + y_min (google.protobuf.wrappers_pb2.FloatValue): The topmost coordinate of the bounding box. - y_max (~.wrappers.FloatValue): + y_max (google.protobuf.wrappers_pb2.FloatValue): The bottommost coordinate of the bounding box. """ diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py index 6a0e7903b2..3853ca87a9 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/__init__.py @@ -66,7 +66,6 @@ AutoMlVideoObjectTrackingInputs, ) - __all__ = ( "ExportEvaluatedDataItemsConfig", "AutoMlForecasting", diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_forecasting.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_forecasting.py index 710793c9a7..34f700f8af 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_forecasting.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_forecasting.py @@ -38,9 +38,9 @@ class AutoMlForecasting(proto.Message): Model. Attributes: - inputs (~.automl_forecasting.AutoMlForecastingInputs): + inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs): The input parameters of this TrainingJob. - metadata (~.automl_forecasting.AutoMlForecastingMetadata): + metadata (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingMetadata): The metadata information. """ @@ -64,7 +64,7 @@ class AutoMlForecastingInputs(proto.Message): time_column (str): The name of the column that identifies time order in the time series. 
- transformations (Sequence[~.automl_forecasting.AutoMlForecastingInputs.Transformation]): + transformations (Sequence[google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation]): Each transformation will apply transform function to given input column. And the result will be used for training. When creating @@ -78,14 +78,14 @@ class AutoMlForecastingInputs(proto.Message): function over the validation set. The supported optimization objectives: - "minimize-rmse" (default) - Minimize root- + "minimize-rmse" (default) - Minimize root- mean-squared error (RMSE). "minimize-mae" - Minimize mean-absolute error (MAE). "minimize- rmsle" - Minimize root-mean-squared log error (RMSLE). "minimize-rmspe" - Minimize root- mean-squared percentage error (RMSPE). "minimize-wape-mae" - Minimize the combination - of weighted absolute percentage error (WAPE) + of weighted absolute percentage error (WAPE) and mean-absolute-error (MAE). train_budget_milli_node_hours (int): Required. The train budget of creating this @@ -131,7 +131,7 @@ class AutoMlForecastingInputs(proto.Message): contains information for the given entity (identified by the key column) that is known for the past and the future - period (~.automl_forecasting.AutoMlForecastingInputs.Period): + period (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Period): Expected difference in time granularity between rows in the data. If it is not set, the period is inferred from data. @@ -152,7 +152,7 @@ class AutoMlForecastingInputs(proto.Message): sequence, where each period is one unit of granularity as defined by the ``period``. Default value 0 means that it lets algorithm to define the value. Inclusive. - export_evaluated_data_items_config (~.gcastd_export_evaluated_data_items_config.ExportEvaluatedDataItemsConfig): + export_evaluated_data_items_config (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.ExportEvaluatedDataItemsConfig): Configuration for exporting test set predictions to a BigQuery table. 
If this configuration is absent, then the export is not @@ -163,21 +163,21 @@ class Transformation(proto.Message): r""" Attributes: - auto (~.automl_forecasting.AutoMlForecastingInputs.Transformation.AutoTransformation): + auto (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.AutoTransformation): - numeric (~.automl_forecasting.AutoMlForecastingInputs.Transformation.NumericTransformation): + numeric (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.NumericTransformation): - categorical (~.automl_forecasting.AutoMlForecastingInputs.Transformation.CategoricalTransformation): + categorical (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.CategoricalTransformation): - timestamp (~.automl_forecasting.AutoMlForecastingInputs.Transformation.TimestampTransformation): + timestamp (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.TimestampTransformation): - text (~.automl_forecasting.AutoMlForecastingInputs.Transformation.TextTransformation): + text (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.TextTransformation): - repeated_numeric (~.automl_forecasting.AutoMlForecastingInputs.Transformation.NumericArrayTransformation): + repeated_numeric (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.NumericArrayTransformation): - repeated_categorical (~.automl_forecasting.AutoMlForecastingInputs.Transformation.CategoricalArrayTransformation): + repeated_categorical (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.CategoricalArrayTransformation): - repeated_text (~.automl_forecasting.AutoMlForecastingInputs.Transformation.TextArrayTransformation): + repeated_text (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlForecastingInputs.Transformation.TextArrayTransformation): """ @@ -418,11 +418,11 @@ class Period(proto.Message): unit (str): The time granularity unit of this time period. The supported unit are: - "hour" - "day" - "week" - "month" - "year". + "hour" + "day" + "week" + "month" + "year". quantity (int): The number of units per period, e.g. 3 weeks or 2 months. diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py index 0ee0394192..8ee27076d2 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_classification.py @@ -33,9 +33,9 @@ class AutoMlImageClassification(proto.Message): Classification Model. Attributes: - inputs (~.automl_image_classification.AutoMlImageClassificationInputs): + inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageClassificationInputs): The input parameters of this TrainingJob. - metadata (~.automl_image_classification.AutoMlImageClassificationMetadata): + metadata (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageClassificationMetadata): The metadata information. 
""" @@ -52,7 +52,7 @@ class AutoMlImageClassificationInputs(proto.Message): r""" Attributes: - model_type (~.automl_image_classification.AutoMlImageClassificationInputs.ModelType): + model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageClassificationInputs.ModelType): base_model_id (str): The ID of the ``base`` model. If it is specified, the new @@ -122,7 +122,7 @@ class AutoMlImageClassificationMetadata(proto.Message): value in this field means 1 node hour. Guaranteed to not exceed inputs.budgetMilliNodeHours. - successful_stop_reason (~.automl_image_classification.AutoMlImageClassificationMetadata.SuccessfulStopReason): + successful_stop_reason (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageClassificationMetadata.SuccessfulStopReason): For successful job completions, this is the reason why the job has finished. """ diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py index 3fb9d3ae1d..512e35ed1d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_object_detection.py @@ -33,9 +33,9 @@ class AutoMlImageObjectDetection(proto.Message): Detection Model. Attributes: - inputs (~.automl_image_object_detection.AutoMlImageObjectDetectionInputs): + inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageObjectDetectionInputs): The input parameters of this TrainingJob. - metadata (~.automl_image_object_detection.AutoMlImageObjectDetectionMetadata): + metadata (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageObjectDetectionMetadata): The metadata information """ @@ -52,7 +52,7 @@ class AutoMlImageObjectDetectionInputs(proto.Message): r""" Attributes: - model_type (~.automl_image_object_detection.AutoMlImageObjectDetectionInputs.ModelType): + model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageObjectDetectionInputs.ModelType): budget_milli_node_hours (int): The training budget of creating this model, expressed in @@ -107,7 +107,7 @@ class AutoMlImageObjectDetectionMetadata(proto.Message): value in this field means 1 node hour. Guaranteed to not exceed inputs.budgetMilliNodeHours. - successful_stop_reason (~.automl_image_object_detection.AutoMlImageObjectDetectionMetadata.SuccessfulStopReason): + successful_stop_reason (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageObjectDetectionMetadata.SuccessfulStopReason): For successful job completions, this is the reason why the job has finished. """ diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py index 0fa3788b11..22c199e7f5 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_image_segmentation.py @@ -33,9 +33,9 @@ class AutoMlImageSegmentation(proto.Message): Segmentation Model. 
Attributes: - inputs (~.automl_image_segmentation.AutoMlImageSegmentationInputs): + inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageSegmentationInputs): The input parameters of this TrainingJob. - metadata (~.automl_image_segmentation.AutoMlImageSegmentationMetadata): + metadata (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageSegmentationMetadata): The metadata information. """ @@ -52,7 +52,7 @@ class AutoMlImageSegmentationInputs(proto.Message): r""" Attributes: - model_type (~.automl_image_segmentation.AutoMlImageSegmentationInputs.ModelType): + model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageSegmentationInputs.ModelType): budget_milli_node_hours (int): The training budget of creating this model, expressed in @@ -100,7 +100,7 @@ class AutoMlImageSegmentationMetadata(proto.Message): value in this field means 1 node hour. Guaranteed to not exceed inputs.budgetMilliNodeHours. - successful_stop_reason (~.automl_image_segmentation.AutoMlImageSegmentationMetadata.SuccessfulStopReason): + successful_stop_reason (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlImageSegmentationMetadata.SuccessfulStopReason): For successful job completions, this is the reason why the job has finished. """ diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py index f924979bd6..19c43929e8 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_tables.py @@ -33,9 +33,9 @@ class AutoMlTables(proto.Message): r"""A TrainingJob that trains and uploads an AutoML Tables Model. Attributes: - inputs (~.automl_tables.AutoMlTablesInputs): + inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs): The input parameters of this TrainingJob. - metadata (~.automl_tables.AutoMlTablesMetadata): + metadata (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesMetadata): The metadata information. """ @@ -61,7 +61,7 @@ class AutoMlTablesInputs(proto.Message): produce. "classification" - Predict one out of multiple target values is picked for each row. - "regression" - Predict a value based on its + "regression" - Predict a value based on its relation to other values. This type is available only to columns that contain semantically numeric values, i.e. integers or @@ -70,7 +70,7 @@ class AutoMlTablesInputs(proto.Message): target_column (str): The column name of the target column that the model is to predict. - transformations (Sequence[~.automl_tables.AutoMlTablesInputs.Transformation]): + transformations (Sequence[google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation]): Each transformation will apply transform function to given input column. And the result will be used for training. When creating @@ -87,11 +87,11 @@ class AutoMlTablesInputs(proto.Message): the prediction type. If the field is not set, a default objective function is used. classification (binary): - "maximize-au-roc" (default) - Maximize the + "maximize-au-roc" (default) - Maximize the area under the receiver operating characteristic (ROC) curve. 
"minimize-log-loss" - Minimize log loss. - "maximize-au-prc" - Maximize the area under + "maximize-au-prc" - Maximize the area under the precision-recall curve. "maximize- precision-at-recall" - Maximize precision for a specified @@ -99,10 +99,10 @@ class AutoMlTablesInputs(proto.Message): Maximize recall for a specified precision value. classification (multi-class): - "minimize-log-loss" (default) - Minimize log + "minimize-log-loss" (default) - Minimize log loss. regression: - "minimize-rmse" (default) - Minimize root- + "minimize-rmse" (default) - Minimize root- mean-squared error (RMSE). "minimize-mae" - Minimize mean-absolute error (MAE). "minimize- rmsle" - Minimize root-mean-squared log error @@ -140,7 +140,7 @@ class AutoMlTablesInputs(proto.Message): for training. If weight column field is not set, then all rows are assumed to have equal weight of 1. - export_evaluated_data_items_config (~.gcastd_export_evaluated_data_items_config.ExportEvaluatedDataItemsConfig): + export_evaluated_data_items_config (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.ExportEvaluatedDataItemsConfig): Configuration for exporting test set predictions to a BigQuery table. If this configuration is absent, then the export is not @@ -151,21 +151,21 @@ class Transformation(proto.Message): r""" Attributes: - auto (~.automl_tables.AutoMlTablesInputs.Transformation.AutoTransformation): + auto (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.AutoTransformation): - numeric (~.automl_tables.AutoMlTablesInputs.Transformation.NumericTransformation): + numeric (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.NumericTransformation): - categorical (~.automl_tables.AutoMlTablesInputs.Transformation.CategoricalTransformation): + categorical (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.CategoricalTransformation): - timestamp (~.automl_tables.AutoMlTablesInputs.Transformation.TimestampTransformation): + timestamp (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.TimestampTransformation): - text (~.automl_tables.AutoMlTablesInputs.Transformation.TextTransformation): + text (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.TextTransformation): - repeated_numeric (~.automl_tables.AutoMlTablesInputs.Transformation.NumericArrayTransformation): + repeated_numeric (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.NumericArrayTransformation): - repeated_categorical (~.automl_tables.AutoMlTablesInputs.Transformation.CategoricalArrayTransformation): + repeated_categorical (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.CategoricalArrayTransformation): - repeated_text (~.automl_tables.AutoMlTablesInputs.Transformation.TextArrayTransformation): + repeated_text (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTablesInputs.Transformation.TextArrayTransformation): """ diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py index ca75734600..9fe6b865c9 100644 --- 
a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_classification.py @@ -29,7 +29,7 @@ class AutoMlTextClassification(proto.Message): Classification Model. Attributes: - inputs (~.automl_text_classification.AutoMlTextClassificationInputs): + inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTextClassificationInputs): The input parameters of this TrainingJob. """ diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py index 336509af22..c7b1fc6dba 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_extraction.py @@ -29,7 +29,7 @@ class AutoMlTextExtraction(proto.Message): Extraction Model. Attributes: - inputs (~.automl_text_extraction.AutoMlTextExtractionInputs): + inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTextExtractionInputs): The input parameters of this TrainingJob. """ diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py index d5de97e2b2..8239b55fdf 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_text_sentiment.py @@ -29,7 +29,7 @@ class AutoMlTextSentiment(proto.Message): Sentiment Model. Attributes: - inputs (~.automl_text_sentiment.AutoMlTextSentimentInputs): + inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlTextSentimentInputs): The input parameters of this TrainingJob. """ diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py index d6969d93c6..66448faf01 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_action_recognition.py @@ -29,7 +29,7 @@ class AutoMlVideoActionRecognition(proto.Message): Recognition Model. Attributes: - inputs (~.automl_video_action_recognition.AutoMlVideoActionRecognitionInputs): + inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoActionRecognitionInputs): The input parameters of this TrainingJob. 
""" @@ -42,7 +42,7 @@ class AutoMlVideoActionRecognitionInputs(proto.Message): r""" Attributes: - model_type (~.automl_video_action_recognition.AutoMlVideoActionRecognitionInputs.ModelType): + model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoActionRecognitionInputs.ModelType): """ diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py index 3164544d47..51195eb327 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_classification.py @@ -29,7 +29,7 @@ class AutoMlVideoClassification(proto.Message): Classification Model. Attributes: - inputs (~.automl_video_classification.AutoMlVideoClassificationInputs): + inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoClassificationInputs): The input parameters of this TrainingJob. """ @@ -42,7 +42,7 @@ class AutoMlVideoClassificationInputs(proto.Message): r""" Attributes: - model_type (~.automl_video_classification.AutoMlVideoClassificationInputs.ModelType): + model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoClassificationInputs.ModelType): """ diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py index 0fd8c7ec7a..328e266a3b 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/types/automl_video_object_tracking.py @@ -29,7 +29,7 @@ class AutoMlVideoObjectTracking(proto.Message): ObjectTracking Model. Attributes: - inputs (~.automl_video_object_tracking.AutoMlVideoObjectTrackingInputs): + inputs (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoObjectTrackingInputs): The input parameters of this TrainingJob. """ @@ -42,7 +42,7 @@ class AutoMlVideoObjectTrackingInputs(proto.Message): r""" Attributes: - model_type (~.automl_video_object_tracking.AutoMlVideoObjectTrackingInputs.ModelType): + model_type (google.cloud.aiplatform.v1beta1.schema.trainingjob.definition_v1beta1.types.AutoMlVideoObjectTrackingInputs.ModelType): """ diff --git a/google/cloud/aiplatform_v1/__init__.py b/google/cloud/aiplatform_v1/__init__.py new file mode 100644 index 0000000000..1b0c76e834 --- /dev/null +++ b/google/cloud/aiplatform_v1/__init__.py @@ -0,0 +1,345 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from .services.dataset_service import DatasetServiceClient +from .services.endpoint_service import EndpointServiceClient +from .services.job_service import JobServiceClient +from .services.migration_service import MigrationServiceClient +from .services.model_service import ModelServiceClient +from .services.pipeline_service import PipelineServiceClient +from .services.prediction_service import PredictionServiceClient +from .services.specialist_pool_service import SpecialistPoolServiceClient +from .types.accelerator_type import AcceleratorType +from .types.annotation import Annotation +from .types.annotation_spec import AnnotationSpec +from .types.batch_prediction_job import BatchPredictionJob +from .types.completion_stats import CompletionStats +from .types.custom_job import ContainerSpec +from .types.custom_job import CustomJob +from .types.custom_job import CustomJobSpec +from .types.custom_job import PythonPackageSpec +from .types.custom_job import Scheduling +from .types.custom_job import WorkerPoolSpec +from .types.data_item import DataItem +from .types.data_labeling_job import ActiveLearningConfig +from .types.data_labeling_job import DataLabelingJob +from .types.data_labeling_job import SampleConfig +from .types.data_labeling_job import TrainingConfig +from .types.dataset import Dataset +from .types.dataset import ExportDataConfig +from .types.dataset import ImportDataConfig +from .types.dataset_service import CreateDatasetOperationMetadata +from .types.dataset_service import CreateDatasetRequest +from .types.dataset_service import DeleteDatasetRequest +from .types.dataset_service import ExportDataOperationMetadata +from .types.dataset_service import ExportDataRequest +from .types.dataset_service import ExportDataResponse +from .types.dataset_service import GetAnnotationSpecRequest +from .types.dataset_service import GetDatasetRequest +from .types.dataset_service import ImportDataOperationMetadata +from .types.dataset_service import ImportDataRequest +from .types.dataset_service import ImportDataResponse +from .types.dataset_service import ListAnnotationsRequest +from .types.dataset_service import ListAnnotationsResponse +from .types.dataset_service import ListDataItemsRequest +from .types.dataset_service import ListDataItemsResponse +from .types.dataset_service import ListDatasetsRequest +from .types.dataset_service import ListDatasetsResponse +from .types.dataset_service import UpdateDatasetRequest +from .types.deployed_model_ref import DeployedModelRef +from .types.encryption_spec import EncryptionSpec +from .types.endpoint import DeployedModel +from .types.endpoint import Endpoint +from .types.endpoint_service import CreateEndpointOperationMetadata +from .types.endpoint_service import CreateEndpointRequest +from .types.endpoint_service import DeleteEndpointRequest +from .types.endpoint_service import DeployModelOperationMetadata +from .types.endpoint_service import DeployModelRequest +from .types.endpoint_service import DeployModelResponse +from .types.endpoint_service import GetEndpointRequest +from .types.endpoint_service import ListEndpointsRequest +from .types.endpoint_service import ListEndpointsResponse +from .types.endpoint_service import UndeployModelOperationMetadata +from .types.endpoint_service import UndeployModelRequest +from .types.endpoint_service import UndeployModelResponse +from .types.endpoint_service import UpdateEndpointRequest +from .types.env_var import EnvVar +from .types.hyperparameter_tuning_job import HyperparameterTuningJob +from .types.io 
import BigQueryDestination +from .types.io import BigQuerySource +from .types.io import ContainerRegistryDestination +from .types.io import GcsDestination +from .types.io import GcsSource +from .types.job_service import CancelBatchPredictionJobRequest +from .types.job_service import CancelCustomJobRequest +from .types.job_service import CancelDataLabelingJobRequest +from .types.job_service import CancelHyperparameterTuningJobRequest +from .types.job_service import CreateBatchPredictionJobRequest +from .types.job_service import CreateCustomJobRequest +from .types.job_service import CreateDataLabelingJobRequest +from .types.job_service import CreateHyperparameterTuningJobRequest +from .types.job_service import DeleteBatchPredictionJobRequest +from .types.job_service import DeleteCustomJobRequest +from .types.job_service import DeleteDataLabelingJobRequest +from .types.job_service import DeleteHyperparameterTuningJobRequest +from .types.job_service import GetBatchPredictionJobRequest +from .types.job_service import GetCustomJobRequest +from .types.job_service import GetDataLabelingJobRequest +from .types.job_service import GetHyperparameterTuningJobRequest +from .types.job_service import ListBatchPredictionJobsRequest +from .types.job_service import ListBatchPredictionJobsResponse +from .types.job_service import ListCustomJobsRequest +from .types.job_service import ListCustomJobsResponse +from .types.job_service import ListDataLabelingJobsRequest +from .types.job_service import ListDataLabelingJobsResponse +from .types.job_service import ListHyperparameterTuningJobsRequest +from .types.job_service import ListHyperparameterTuningJobsResponse +from .types.job_state import JobState +from .types.machine_resources import AutomaticResources +from .types.machine_resources import BatchDedicatedResources +from .types.machine_resources import DedicatedResources +from .types.machine_resources import DiskSpec +from .types.machine_resources import MachineSpec +from .types.machine_resources import ResourcesConsumed +from .types.manual_batch_tuning_parameters import ManualBatchTuningParameters +from .types.migratable_resource import MigratableResource +from .types.migration_service import BatchMigrateResourcesOperationMetadata +from .types.migration_service import BatchMigrateResourcesRequest +from .types.migration_service import BatchMigrateResourcesResponse +from .types.migration_service import MigrateResourceRequest +from .types.migration_service import MigrateResourceResponse +from .types.migration_service import SearchMigratableResourcesRequest +from .types.migration_service import SearchMigratableResourcesResponse +from .types.model import Model +from .types.model import ModelContainerSpec +from .types.model import Port +from .types.model import PredictSchemata +from .types.model_evaluation import ModelEvaluation +from .types.model_evaluation_slice import ModelEvaluationSlice +from .types.model_service import DeleteModelRequest +from .types.model_service import ExportModelOperationMetadata +from .types.model_service import ExportModelRequest +from .types.model_service import ExportModelResponse +from .types.model_service import GetModelEvaluationRequest +from .types.model_service import GetModelEvaluationSliceRequest +from .types.model_service import GetModelRequest +from .types.model_service import ListModelEvaluationSlicesRequest +from .types.model_service import ListModelEvaluationSlicesResponse +from .types.model_service import ListModelEvaluationsRequest +from .types.model_service import 
ListModelEvaluationsResponse +from .types.model_service import ListModelsRequest +from .types.model_service import ListModelsResponse +from .types.model_service import UpdateModelRequest +from .types.model_service import UploadModelOperationMetadata +from .types.model_service import UploadModelRequest +from .types.model_service import UploadModelResponse +from .types.operation import DeleteOperationMetadata +from .types.operation import GenericOperationMetadata +from .types.pipeline_service import CancelTrainingPipelineRequest +from .types.pipeline_service import CreateTrainingPipelineRequest +from .types.pipeline_service import DeleteTrainingPipelineRequest +from .types.pipeline_service import GetTrainingPipelineRequest +from .types.pipeline_service import ListTrainingPipelinesRequest +from .types.pipeline_service import ListTrainingPipelinesResponse +from .types.pipeline_state import PipelineState +from .types.prediction_service import PredictRequest +from .types.prediction_service import PredictResponse +from .types.specialist_pool import SpecialistPool +from .types.specialist_pool_service import CreateSpecialistPoolOperationMetadata +from .types.specialist_pool_service import CreateSpecialistPoolRequest +from .types.specialist_pool_service import DeleteSpecialistPoolRequest +from .types.specialist_pool_service import GetSpecialistPoolRequest +from .types.specialist_pool_service import ListSpecialistPoolsRequest +from .types.specialist_pool_service import ListSpecialistPoolsResponse +from .types.specialist_pool_service import UpdateSpecialistPoolOperationMetadata +from .types.specialist_pool_service import UpdateSpecialistPoolRequest +from .types.study import Measurement +from .types.study import StudySpec +from .types.study import Trial +from .types.training_pipeline import FilterSplit +from .types.training_pipeline import FractionSplit +from .types.training_pipeline import InputDataConfig +from .types.training_pipeline import PredefinedSplit +from .types.training_pipeline import TimestampSplit +from .types.training_pipeline import TrainingPipeline +from .types.user_action_reference import UserActionReference + + +__all__ = ( + "AcceleratorType", + "ActiveLearningConfig", + "Annotation", + "AnnotationSpec", + "AutomaticResources", + "BatchDedicatedResources", + "BatchMigrateResourcesOperationMetadata", + "BatchMigrateResourcesRequest", + "BatchMigrateResourcesResponse", + "BatchPredictionJob", + "BigQueryDestination", + "BigQuerySource", + "CancelBatchPredictionJobRequest", + "CancelCustomJobRequest", + "CancelDataLabelingJobRequest", + "CancelHyperparameterTuningJobRequest", + "CancelTrainingPipelineRequest", + "CompletionStats", + "ContainerRegistryDestination", + "ContainerSpec", + "CreateBatchPredictionJobRequest", + "CreateCustomJobRequest", + "CreateDataLabelingJobRequest", + "CreateDatasetOperationMetadata", + "CreateDatasetRequest", + "CreateEndpointOperationMetadata", + "CreateEndpointRequest", + "CreateHyperparameterTuningJobRequest", + "CreateSpecialistPoolOperationMetadata", + "CreateSpecialistPoolRequest", + "CreateTrainingPipelineRequest", + "CustomJob", + "CustomJobSpec", + "DataItem", + "DataLabelingJob", + "Dataset", + "DatasetServiceClient", + "DedicatedResources", + "DeleteBatchPredictionJobRequest", + "DeleteCustomJobRequest", + "DeleteDataLabelingJobRequest", + "DeleteDatasetRequest", + "DeleteEndpointRequest", + "DeleteHyperparameterTuningJobRequest", + "DeleteModelRequest", + "DeleteOperationMetadata", + "DeleteSpecialistPoolRequest", + 
"DeleteTrainingPipelineRequest", + "DeployModelOperationMetadata", + "DeployModelRequest", + "DeployModelResponse", + "DeployedModel", + "DeployedModelRef", + "DiskSpec", + "EncryptionSpec", + "Endpoint", + "EndpointServiceClient", + "EnvVar", + "ExportDataConfig", + "ExportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "ExportModelOperationMetadata", + "ExportModelRequest", + "ExportModelResponse", + "FilterSplit", + "FractionSplit", + "GcsDestination", + "GcsSource", + "GenericOperationMetadata", + "GetAnnotationSpecRequest", + "GetBatchPredictionJobRequest", + "GetCustomJobRequest", + "GetDataLabelingJobRequest", + "GetDatasetRequest", + "GetEndpointRequest", + "GetHyperparameterTuningJobRequest", + "GetModelEvaluationRequest", + "GetModelEvaluationSliceRequest", + "GetModelRequest", + "GetSpecialistPoolRequest", + "GetTrainingPipelineRequest", + "HyperparameterTuningJob", + "ImportDataConfig", + "ImportDataOperationMetadata", + "ImportDataRequest", + "ImportDataResponse", + "InputDataConfig", + "JobServiceClient", + "JobState", + "ListAnnotationsRequest", + "ListAnnotationsResponse", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "ListDataItemsRequest", + "ListDataItemsResponse", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "ListDatasetsRequest", + "ListDatasetsResponse", + "ListEndpointsRequest", + "ListEndpointsResponse", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "ListModelsRequest", + "ListModelsResponse", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "MachineSpec", + "ManualBatchTuningParameters", + "Measurement", + "MigratableResource", + "MigrateResourceRequest", + "MigrateResourceResponse", + "MigrationServiceClient", + "Model", + "ModelContainerSpec", + "ModelEvaluation", + "ModelEvaluationSlice", + "ModelServiceClient", + "PipelineServiceClient", + "PipelineState", + "Port", + "PredefinedSplit", + "PredictRequest", + "PredictResponse", + "PredictSchemata", + "PredictionServiceClient", + "PythonPackageSpec", + "ResourcesConsumed", + "SampleConfig", + "Scheduling", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "SpecialistPool", + "StudySpec", + "TimestampSplit", + "TrainingConfig", + "TrainingPipeline", + "Trial", + "UndeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UpdateDatasetRequest", + "UpdateEndpointRequest", + "UpdateModelRequest", + "UpdateSpecialistPoolOperationMetadata", + "UpdateSpecialistPoolRequest", + "UploadModelOperationMetadata", + "UploadModelRequest", + "UploadModelResponse", + "UserActionReference", + "WorkerPoolSpec", + "SpecialistPoolServiceClient", +) diff --git a/google/cloud/aiplatform_v1/py.typed b/google/cloud/aiplatform_v1/py.typed new file mode 100644 index 0000000000..228f1c51c6 --- /dev/null +++ b/google/cloud/aiplatform_v1/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-cloud-aiplatform package uses inline types. 
diff --git a/google/cloud/aiplatform_v1/services/__init__.py b/google/cloud/aiplatform_v1/services/__init__.py new file mode 100644 index 0000000000..42ffdf2bc4 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/google/cloud/aiplatform_v1/services/dataset_service/__init__.py b/google/cloud/aiplatform_v1/services/dataset_service/__init__.py new file mode 100644 index 0000000000..597f654cb9 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/dataset_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .client import DatasetServiceClient +from .async_client import DatasetServiceAsyncClient + +__all__ = ( + "DatasetServiceClient", + "DatasetServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py new file mode 100644 index 0000000000..d5b56b54f5 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/dataset_service/async_client.py @@ -0,0 +1,1044 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.dataset_service import pagers +from google.cloud.aiplatform_v1.types import annotation +from google.cloud.aiplatform_v1.types import annotation_spec +from google.cloud.aiplatform_v1.types import data_item +from google.cloud.aiplatform_v1.types import dataset +from google.cloud.aiplatform_v1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1.types import dataset_service +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import DatasetServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import DatasetServiceGrpcAsyncIOTransport +from .client import DatasetServiceClient + + +class DatasetServiceAsyncClient: + """The service that handles the CRUD of Dataset and all its children resources.""" + + _client: DatasetServiceClient + + DEFAULT_ENDPOINT = DatasetServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = DatasetServiceClient.DEFAULT_MTLS_ENDPOINT + + annotation_path = staticmethod(DatasetServiceClient.annotation_path) + parse_annotation_path = staticmethod(DatasetServiceClient.parse_annotation_path) + annotation_spec_path = staticmethod(DatasetServiceClient.annotation_spec_path) + parse_annotation_spec_path = staticmethod( + DatasetServiceClient.parse_annotation_spec_path + ) + data_item_path = staticmethod(DatasetServiceClient.data_item_path) + parse_data_item_path = staticmethod(DatasetServiceClient.parse_data_item_path) + dataset_path = staticmethod(DatasetServiceClient.dataset_path) + parse_dataset_path = staticmethod(DatasetServiceClient.parse_dataset_path) + + common_billing_account_path = staticmethod( + DatasetServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + DatasetServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(DatasetServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + DatasetServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + DatasetServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + DatasetServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(DatasetServiceClient.common_project_path) + parse_common_project_path = staticmethod( + DatasetServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod(DatasetServiceClient.common_location_path) + parse_common_location_path = staticmethod( + DatasetServiceClient.parse_common_location_path + ) + + from_service_account_info = DatasetServiceClient.from_service_account_info +
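+    # These convenience constructors are shared with the synchronous client;
+    # for example (the key-file path is a hypothetical placeholder):
+    #
+    #     client = DatasetServiceAsyncClient.from_service_account_file("key.json")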
from_service_account_file = DatasetServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> DatasetServiceTransport: + """Return the transport used by the client instance. + + Returns: + DatasetServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(DatasetServiceClient).get_transport_class, type(DatasetServiceClient) + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, DatasetServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the dataset service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.DatasetServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = DatasetServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def create_dataset( + self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a Dataset. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.CreateDatasetRequest`): + The request object. Request message for + ``DatasetService.CreateDataset``. + parent (:class:`str`): + Required. The resource name of the Location to create + the Dataset in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + dataset (:class:`google.cloud.aiplatform_v1.types.Dataset`): + Required. The Dataset to create. + This corresponds to the ``dataset`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+ + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.Dataset` A + collection of DataItems and Annotations on them. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, dataset]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.CreateDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if dataset is not None: + request.dataset = dataset + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_dataset, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_dataset.Dataset, + metadata_type=dataset_service.CreateDatasetOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_dataset( + self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: + r"""Gets a Dataset. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.GetDatasetRequest`): + The request object. Request message for + ``DatasetService.GetDataset``. + name (:class:`str`): + Required. The name of the Dataset + resource. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Dataset: + A collection of DataItems and + Annotations on them. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) + + request = dataset_service.GetDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_dataset, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def update_dataset( + self, + request: dataset_service.UpdateDatasetRequest = None, + *, + dataset: gca_dataset.Dataset = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_dataset.Dataset: + r"""Updates a Dataset. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.UpdateDatasetRequest`): + The request object. Request message for + ``DatasetService.UpdateDataset``. + dataset (:class:`google.cloud.aiplatform_v1.types.Dataset`): + Required. The Dataset which replaces + the resource on the server. + + This corresponds to the ``dataset`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The update mask applies to the resource. For + the ``FieldMask`` definition, see + `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__. + Updatable fields: + + - ``display_name`` + - ``description`` + - ``labels`` + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Dataset: + A collection of DataItems and + Annotations on them. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([dataset, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.UpdateDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if dataset is not None: + request.dataset = dataset + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_dataset, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("dataset.name", request.dataset.name),) + ), + ) + + # Send the request.
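+        # Arguments left at gapic_v1.method.DEFAULT fall back to the defaults
+        # configured by wrap_method above; explicit retry/timeout values passed
+        # by the caller override them for this call only.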
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_datasets( + self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsAsyncPager: + r"""Lists Datasets in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListDatasetsRequest`): + The request object. Request message for + ``DatasetService.ListDatasets``. + parent (:class:`str`): + Required. The name of the Dataset's parent resource. + Format: ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.dataset_service.pagers.ListDatasetsAsyncPager: + Response message for + ``DatasetService.ListDatasets``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.ListDatasetsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_datasets, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListDatasetsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_dataset( + self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a Dataset. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.DeleteDatasetRequest`): + The request object. Request message for + ``DatasetService.DeleteDataset``. + name (:class:`str`): + Required. The resource name of the Dataset to delete. 
+ Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.DeleteDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_dataset, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def import_data( + self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Imports data into a Dataset. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ImportDataRequest`): + The request object. Request message for + ``DatasetService.ImportData``. + name (:class:`str`): + Required. The name of the Dataset resource. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + import_configs (:class:`Sequence[google.cloud.aiplatform_v1.types.ImportDataConfig]`): + Required. The desired input + locations. The contents of all input + locations will be imported in one batch. 
+ + This corresponds to the ``import_configs`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.ImportDataResponse` + Response message for + ``DatasetService.ImportData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name, import_configs]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.ImportDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + if import_configs: + request.import_configs.extend(import_configs) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.import_data, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + dataset_service.ImportDataResponse, + metadata_type=dataset_service.ImportDataOperationMetadata, + ) + + # Done; return the response. + return response + + async def export_data( + self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Exports data from a Dataset. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ExportDataRequest`): + The request object. Request message for + ``DatasetService.ExportData``. + name (:class:`str`): + Required. The name of the Dataset resource. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + export_config (:class:`google.cloud.aiplatform_v1.types.ExportDataConfig`): + Required. The desired output + location. + + This corresponds to the ``export_config`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.ExportDataResponse` + Response message for + ``DatasetService.ExportData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name, export_config]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.ExportDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + if export_config is not None: + request.export_config = export_config + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.export_data, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + dataset_service.ExportDataResponse, + metadata_type=dataset_service.ExportDataOperationMetadata, + ) + + # Done; return the response. + return response + + async def list_data_items( + self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsAsyncPager: + r"""Lists DataItems in a Dataset. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListDataItemsRequest`): + The request object. Request message for + ``DatasetService.ListDataItems``. + parent (:class:`str`): + Required. The resource name of the Dataset to list + DataItems from. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.dataset_service.pagers.ListDataItemsAsyncPager: + Response message for + ``DatasetService.ListDataItems``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) + + request = dataset_service.ListDataItemsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_data_items, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListDataItemsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_annotation_spec( + self, + request: dataset_service.GetAnnotationSpecRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> annotation_spec.AnnotationSpec: + r"""Gets an AnnotationSpec. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.GetAnnotationSpecRequest`): + The request object. Request message for + ``DatasetService.GetAnnotationSpec``. + name (:class:`str`): + Required. The name of the AnnotationSpec resource. + Format: + + ``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.AnnotationSpec: + Identifies a concept with which + DataItems may be annotated. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.GetAnnotationSpecRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_annotation_spec, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response.
+ return response + + async def list_annotations( + self, + request: dataset_service.ListAnnotationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListAnnotationsAsyncPager: + r"""Lists Annotations that belong to a DataItem. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListAnnotationsRequest`): + The request object. Request message for + ``DatasetService.ListAnnotations``. + parent (:class:`str`): + Required. The resource name of the DataItem to list + Annotations from. Format: + + ``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.dataset_service.pagers.ListAnnotationsAsyncPager: + Response message for + ``DatasetService.ListAnnotations``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = dataset_service.ListAnnotationsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_annotations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListAnnotationsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response.
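+        # The pager can be consumed with `async for annotation in response:`;
+        # further pages are requested lazily as iteration proceeds.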
+ return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("DatasetServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/client.py b/google/cloud/aiplatform_v1/services/dataset_service/client.py new file mode 100644 index 0000000000..e545dbe56e --- /dev/null +++ b/google/cloud/aiplatform_v1/services/dataset_service/client.py @@ -0,0 +1,1309 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.dataset_service import pagers +from google.cloud.aiplatform_v1.types import annotation +from google.cloud.aiplatform_v1.types import annotation_spec +from google.cloud.aiplatform_v1.types import data_item +from google.cloud.aiplatform_v1.types import dataset +from google.cloud.aiplatform_v1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1.types import dataset_service +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import DatasetServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import DatasetServiceGrpcTransport +from .transports.grpc_asyncio import DatasetServiceGrpcAsyncIOTransport + + +class DatasetServiceClientMeta(type): + """Metaclass for the DatasetService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. 
+ """ + + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[DatasetServiceTransport]] + _transport_registry["grpc"] = DatasetServiceGrpcTransport + _transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[DatasetServiceTransport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class DatasetServiceClient(metaclass=DatasetServiceClientMeta): + """""" + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatasetServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatasetServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> DatasetServiceTransport: + """Return the transport used by the client instance. + + Returns: + DatasetServiceTransport: The transport used by the client instance. 
+ """ + return self._transport + + @staticmethod + def annotation_path( + project: str, location: str, dataset: str, data_item: str, annotation: str, + ) -> str: + """Return a fully-qualified annotation string.""" + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format( + project=project, + location=location, + dataset=dataset, + data_item=data_item, + annotation=annotation, + ) + + @staticmethod + def parse_annotation_path(path: str) -> Dict[str, str]: + """Parse a annotation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)/annotations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def annotation_spec_path( + project: str, location: str, dataset: str, annotation_spec: str, + ) -> str: + """Return a fully-qualified annotation_spec string.""" + return "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( + project=project, + location=location, + dataset=dataset, + annotation_spec=annotation_spec, + ) + + @staticmethod + def parse_annotation_spec_path(path: str) -> Dict[str, str]: + """Parse a annotation_spec path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/annotationSpecs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def data_item_path( + project: str, location: str, dataset: str, data_item: str, + ) -> str: + """Return a fully-qualified data_item string.""" + return "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) + + @staticmethod + def parse_data_item_path(path: str) -> Dict[str, str]: + """Parse a data_item path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)/dataItems/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def dataset_path(project: str, location: str, dataset: str,) -> str: + """Return a fully-qualified dataset string.""" + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) + + @staticmethod + def parse_dataset_path(path: str) -> Dict[str, str]: + """Parse a dataset path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified 
organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, DatasetServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the dataset service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, DatasetServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. 
+ """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, DatasetServiceTransport): + # transport is a DatasetServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def create_dataset( + self, + request: dataset_service.CreateDatasetRequest = None, + *, + parent: str = None, + dataset: gca_dataset.Dataset = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Creates a Dataset. + + Args: + request (google.cloud.aiplatform_v1.types.CreateDatasetRequest): + The request object. Request message for + ``DatasetService.CreateDataset``. + parent (str): + Required. The resource name of the Location to create + the Dataset in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + dataset (google.cloud.aiplatform_v1.types.Dataset): + Required. The Dataset to create. + This corresponds to the ``dataset`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.Dataset` A + collection of DataItems and Annotations on them. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, dataset]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.CreateDatasetRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.CreateDatasetRequest): + request = dataset_service.CreateDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if dataset is not None: + request.dataset = dataset + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_dataset] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + gca_dataset.Dataset, + metadata_type=dataset_service.CreateDatasetOperationMetadata, + ) + + # Done; return the response. + return response + + def get_dataset( + self, + request: dataset_service.GetDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> dataset.Dataset: + r"""Gets a Dataset. + + Args: + request (google.cloud.aiplatform_v1.types.GetDatasetRequest): + The request object. Request message for + ``DatasetService.GetDataset``. + name (str): + Required. The name of the Dataset + resource. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Dataset: + A collection of DataItems and + Annotations on them. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.GetDatasetRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
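+        # A plain mapping is accepted here as well: the proto-plus constructor
+        # below coerces dicts, so request={"name": dataset_name} (with
+        # dataset_name as an illustrative placeholder) behaves like a typed
+        # GetDatasetRequest.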
+        if not isinstance(request, dataset_service.GetDatasetRequest):
+            request = dataset_service.GetDatasetRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if name is not None:
+            request.name = name
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.get_dataset]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+        )
+
+        # Send the request.
+        response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # Done; return the response.
+        return response
+
+    def update_dataset(
+        self,
+        request: dataset_service.UpdateDatasetRequest = None,
+        *,
+        dataset: gca_dataset.Dataset = None,
+        update_mask: field_mask.FieldMask = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> gca_dataset.Dataset:
+        r"""Updates a Dataset.
+
+        Args:
+            request (google.cloud.aiplatform_v1.types.UpdateDatasetRequest):
+                The request object. Request message for
+                ``DatasetService.UpdateDataset``.
+            dataset (google.cloud.aiplatform_v1.types.Dataset):
+                Required. The Dataset which replaces
+                the resource on the server.
+
+                This corresponds to the ``dataset`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            update_mask (google.protobuf.field_mask_pb2.FieldMask):
+                Required. The update mask applies to the resource. For
+                the ``FieldMask`` definition, see
+                `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__.
+                Updatable fields:
+
+                -  ``display_name``
+                -  ``description``
+                -  ``labels``
+
+                This corresponds to the ``update_mask`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.cloud.aiplatform_v1.types.Dataset:
+                A collection of DataItems and
+                Annotations on them.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([dataset, update_mask])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        # Minor optimization to avoid making a copy if the user passes
+        # in a dataset_service.UpdateDatasetRequest.
+        # There's no risk of modifying the input as we've already verified
+        # there are no flattened fields.
+        if not isinstance(request, dataset_service.UpdateDatasetRequest):
+            request = dataset_service.UpdateDatasetRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if dataset is not None:
+            request.dataset = dataset
+        if update_mask is not None:
+            request.update_mask = update_mask
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.update_dataset]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
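+        # For UpdateDataset the routing value is the dataset's own resource
+        # name; to_grpc_metadata() renders it as the x-goog-request-params
+        # header that the backend uses to route the request.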
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("dataset.name", request.dataset.name),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_datasets( + self, + request: dataset_service.ListDatasetsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDatasetsPager: + r"""Lists Datasets in a Location. + + Args: + request (google.cloud.aiplatform_v1.types.ListDatasetsRequest): + The request object. Request message for + ``DatasetService.ListDatasets``. + parent (str): + Required. The name of the Dataset's parent resource. + Format: ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.dataset_service.pagers.ListDatasetsPager: + Response message for + ``DatasetService.ListDatasets``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ListDatasetsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ListDatasetsRequest): + request = dataset_service.ListDatasetsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_datasets] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListDatasetsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_dataset( + self, + request: dataset_service.DeleteDatasetRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Deletes a Dataset. + + Args: + request (google.cloud.aiplatform_v1.types.DeleteDatasetRequest): + The request object. 
Request message for + ``DatasetService.DeleteDataset``. + name (str): + Required. The resource name of the Dataset to delete. + Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.DeleteDatasetRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.DeleteDatasetRequest): + request = dataset_service.DeleteDatasetRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_dataset] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def import_data( + self, + request: dataset_service.ImportDataRequest = None, + *, + name: str = None, + import_configs: Sequence[dataset.ImportDataConfig] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Imports data into a Dataset. + + Args: + request (google.cloud.aiplatform_v1.types.ImportDataRequest): + The request object. Request message for + ``DatasetService.ImportData``. + name (str): + Required. The name of the Dataset resource. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+ import_configs (Sequence[google.cloud.aiplatform_v1.types.ImportDataConfig]): + Required. The desired input + locations. The contents of all input + locations will be imported in one batch. + + This corresponds to the ``import_configs`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.ImportDataResponse` + Response message for + ``DatasetService.ImportData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name, import_configs]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ImportDataRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ImportDataRequest): + request = dataset_service.ImportDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + if import_configs: + request.import_configs.extend(import_configs) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.import_data] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + dataset_service.ImportDataResponse, + metadata_type=dataset_service.ImportDataOperationMetadata, + ) + + # Done; return the response. + return response + + def export_data( + self, + request: dataset_service.ExportDataRequest = None, + *, + name: str = None, + export_config: dataset.ExportDataConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Exports data from a Dataset. + + Args: + request (google.cloud.aiplatform_v1.types.ExportDataRequest): + The request object. Request message for + ``DatasetService.ExportData``. + name (str): + Required. The name of the Dataset resource. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + export_config (google.cloud.aiplatform_v1.types.ExportDataConfig): + Required. The desired output + location. 
+ + This corresponds to the ``export_config`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.ExportDataResponse` + Response message for + ``DatasetService.ExportData``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name, export_config]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a dataset_service.ExportDataRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, dataset_service.ExportDataRequest): + request = dataset_service.ExportDataRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + if export_config is not None: + request.export_config = export_config + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.export_data] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + dataset_service.ExportDataResponse, + metadata_type=dataset_service.ExportDataOperationMetadata, + ) + + # Done; return the response. + return response + + def list_data_items( + self, + request: dataset_service.ListDataItemsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataItemsPager: + r"""Lists DataItems in a Dataset. + + Args: + request (google.cloud.aiplatform_v1.types.ListDataItemsRequest): + The request object. Request message for + ``DatasetService.ListDataItems``. + parent (str): + Required. The resource name of the Dataset to list + DataItems from. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.dataset_service.pagers.ListDataItemsPager: + Response message for + ``DatasetService.ListDataItems``. 
+
+        Iterating over this object will yield results and
+        resolve additional pages automatically.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([parent])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        # Minor optimization to avoid making a copy if the user passes
+        # in a dataset_service.ListDataItemsRequest.
+        # There's no risk of modifying the input as we've already verified
+        # there are no flattened fields.
+        if not isinstance(request, dataset_service.ListDataItemsRequest):
+            request = dataset_service.ListDataItemsRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if parent is not None:
+            request.parent = parent
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.list_data_items]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+        )
+
+        # Send the request.
+        response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__iter__` convenience method.
+        response = pagers.ListDataItemsPager(
+            method=rpc, request=request, response=response, metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    def get_annotation_spec(
+        self,
+        request: dataset_service.GetAnnotationSpecRequest = None,
+        *,
+        name: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> annotation_spec.AnnotationSpec:
+        r"""Gets an AnnotationSpec.
+
+        Args:
+            request (google.cloud.aiplatform_v1.types.GetAnnotationSpecRequest):
+                The request object. Request message for
+                ``DatasetService.GetAnnotationSpec``.
+            name (str):
+                Required. The name of the AnnotationSpec resource.
+                Format:
+
+                ``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}``
+
+                This corresponds to the ``name`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.cloud.aiplatform_v1.types.AnnotationSpec:
+                Identifies a concept with which
+                DataItems may be annotated.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([name])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        # Minor optimization to avoid making a copy if the user passes
+        # in a dataset_service.GetAnnotationSpecRequest.
+        # There's no risk of modifying the input as we've already verified
+        # there are no flattened fields.
+        if not isinstance(request, dataset_service.GetAnnotationSpecRequest):
+            request = dataset_service.GetAnnotationSpecRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if name is not None:
+            request.name = name
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.get_annotation_spec]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+        )
+
+        # Send the request.
+        response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # Done; return the response.
+        return response
+
+    def list_annotations(
+        self,
+        request: dataset_service.ListAnnotationsRequest = None,
+        *,
+        parent: str = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> pagers.ListAnnotationsPager:
+        r"""Lists Annotations that belong to a DataItem.
+
+        Args:
+            request (google.cloud.aiplatform_v1.types.ListAnnotationsRequest):
+                The request object. Request message for
+                ``DatasetService.ListAnnotations``.
+            parent (str):
+                Required. The resource name of the DataItem to list
+                Annotations from. Format:
+
+                ``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}``
+
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.cloud.aiplatform_v1.services.dataset_service.pagers.ListAnnotationsPager:
+                Response message for
+                ``DatasetService.ListAnnotations``.
+
+                Iterating over this object will yield results and
+                resolve additional pages automatically.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([parent])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        # Minor optimization to avoid making a copy if the user passes
+        # in a dataset_service.ListAnnotationsRequest.
+        # There's no risk of modifying the input as we've already verified
+        # there are no flattened fields.
+        if not isinstance(request, dataset_service.ListAnnotationsRequest):
+            request = dataset_service.ListAnnotationsRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if parent is not None:
+            request.parent = parent
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.list_annotations]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+        )
+
+        # Send the request.
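+        # The pager constructed below re-issues this RPC for each subsequent
+        # page with the same metadata; those per-page calls fall back to the
+        # wrapped method's default retry and timeout rather than the values
+        # passed in here.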
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListAnnotationsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("DatasetServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1/services/dataset_service/pagers.py new file mode 100644 index 0000000000..f195ca3308 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/dataset_service/pagers.py @@ -0,0 +1,407 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple + +from google.cloud.aiplatform_v1.types import annotation +from google.cloud.aiplatform_v1.types import data_item +from google.cloud.aiplatform_v1.types import dataset +from google.cloud.aiplatform_v1.types import dataset_service + + +class ListDatasetsPager: + """A pager for iterating through ``list_datasets`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListDatasetsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``datasets`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListDatasets`` requests and continue to iterate + through the ``datasets`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListDatasetsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., dataset_service.ListDatasetsResponse], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListDatasetsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListDatasetsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
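+
+        Note that in normal use this pager comes back from
+        ``DatasetServiceClient.list_datasets`` rather than being constructed
+        directly; iterating the returned pager (for example,
+        ``for ds in client.list_datasets(parent=parent)``) pulls pages lazily.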
+ """ + self._method = method + self._request = dataset_service.ListDatasetsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[dataset_service.ListDatasetsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[dataset.Dataset]: + for page in self.pages: + yield from page.datasets + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListDatasetsAsyncPager: + """A pager for iterating through ``list_datasets`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListDatasetsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``datasets`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListDatasets`` requests and continue to iterate + through the ``datasets`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListDatasetsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListDatasetsResponse]], + request: dataset_service.ListDatasetsRequest, + response: dataset_service.ListDatasetsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListDatasetsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListDatasetsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = dataset_service.ListDatasetsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[dataset_service.ListDatasetsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[dataset.Dataset]: + async def async_generator(): + async for page in self.pages: + for response in page.datasets: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListDataItemsPager: + """A pager for iterating through ``list_data_items`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListDataItemsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``data_items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListDataItems`` requests and continue to iterate + through the ``data_items`` field on the + corresponding responses. 
+ + All the usual :class:`google.cloud.aiplatform_v1.types.ListDataItemsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., dataset_service.ListDataItemsResponse], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListDataItemsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListDataItemsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = dataset_service.ListDataItemsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[dataset_service.ListDataItemsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[data_item.DataItem]: + for page in self.pages: + yield from page.data_items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListDataItemsAsyncPager: + """A pager for iterating through ``list_data_items`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListDataItemsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``data_items`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListDataItems`` requests and continue to iterate + through the ``data_items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListDataItemsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListDataItemsResponse]], + request: dataset_service.ListDataItemsRequest, + response: dataset_service.ListDataItemsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListDataItemsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListDataItemsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
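+
+        In normal use this pager comes back from
+        ``DatasetServiceAsyncClient.list_data_items`` and is consumed with
+        ``async for`` (an illustrative sketch, with ``parent`` a placeholder):
+
+            async for item in await client.list_data_items(parent=parent):
+                ...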
+ """ + self._method = method + self._request = dataset_service.ListDataItemsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[dataset_service.ListDataItemsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[data_item.DataItem]: + async def async_generator(): + async for page in self.pages: + for response in page.data_items: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListAnnotationsPager: + """A pager for iterating through ``list_annotations`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListAnnotationsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``annotations`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListAnnotations`` requests and continue to iterate + through the ``annotations`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListAnnotationsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., dataset_service.ListAnnotationsResponse], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListAnnotationsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListAnnotationsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = dataset_service.ListAnnotationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[dataset_service.ListAnnotationsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[annotation.Annotation]: + for page in self.pages: + yield from page.annotations + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListAnnotationsAsyncPager: + """A pager for iterating through ``list_annotations`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListAnnotationsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``annotations`` field. 
+ + If there are more pages, the ``__aiter__`` method will make additional + ``ListAnnotations`` requests and continue to iterate + through the ``annotations`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListAnnotationsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[dataset_service.ListAnnotationsResponse]], + request: dataset_service.ListAnnotationsRequest, + response: dataset_service.ListAnnotationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListAnnotationsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListAnnotationsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = dataset_service.ListAnnotationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[dataset_service.ListAnnotationsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[annotation.Annotation]: + async def async_generator(): + async for page in self.pages: + for response in page.annotations: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py new file mode 100644 index 0000000000..a4461d2ced --- /dev/null +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import DatasetServiceTransport +from .grpc import DatasetServiceGrpcTransport +from .grpc_asyncio import DatasetServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. 
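+# The client's get_transport_class() helper looks transports up here by name,
+# so transport="grpc" or transport="grpc_asyncio" selects the matching class;
+# when no name is given, the first registered entry is the default.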
+_transport_registry = OrderedDict()  # type: Dict[str, Type[DatasetServiceTransport]]
+_transport_registry["grpc"] = DatasetServiceGrpcTransport
+_transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport
+
+__all__ = (
+    "DatasetServiceTransport",
+    "DatasetServiceGrpcTransport",
+    "DatasetServiceGrpcAsyncIOTransport",
+)
diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py
new file mode 100644
index 0000000000..2ab4419d03
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/base.py
@@ -0,0 +1,254 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+import typing
+import pkg_resources
+
+from google import auth  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google.auth import credentials  # type: ignore
+
+from google.cloud.aiplatform_v1.types import annotation_spec
+from google.cloud.aiplatform_v1.types import dataset
+from google.cloud.aiplatform_v1.types import dataset as gca_dataset
+from google.cloud.aiplatform_v1.types import dataset_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            "google-cloud-aiplatform",
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+class DatasetServiceTransport(abc.ABC):
+    """Abstract transport class for DatasetService."""
+
+    AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: typing.Optional[str] = None,
+        scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+        quota_project_id: typing.Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        **kwargs,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + """ + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) + + # Save the credentials. + self._credentials = credentials + + # Lifted into its own function so it can be stubbed out during tests. + self._prep_wrapped_messages(client_info) + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_dataset: gapic_v1.method.wrap_method( + self.create_dataset, default_timeout=None, client_info=client_info, + ), + self.get_dataset: gapic_v1.method.wrap_method( + self.get_dataset, default_timeout=None, client_info=client_info, + ), + self.update_dataset: gapic_v1.method.wrap_method( + self.update_dataset, default_timeout=None, client_info=client_info, + ), + self.list_datasets: gapic_v1.method.wrap_method( + self.list_datasets, default_timeout=None, client_info=client_info, + ), + self.delete_dataset: gapic_v1.method.wrap_method( + self.delete_dataset, default_timeout=None, client_info=client_info, + ), + self.import_data: gapic_v1.method.wrap_method( + self.import_data, default_timeout=None, client_info=client_info, + ), + self.export_data: gapic_v1.method.wrap_method( + self.export_data, default_timeout=None, client_info=client_info, + ), + self.list_data_items: gapic_v1.method.wrap_method( + self.list_data_items, default_timeout=None, client_info=client_info, + ), + self.get_annotation_spec: gapic_v1.method.wrap_method( + self.get_annotation_spec, default_timeout=None, client_info=client_info, + ), + self.list_annotations: gapic_v1.method.wrap_method( + self.list_annotations, default_timeout=None, client_info=client_info, + ), + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def create_dataset( + self, + ) -> typing.Callable[ + [dataset_service.CreateDatasetRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def get_dataset( + self, + ) -> typing.Callable[ + [dataset_service.GetDatasetRequest], + typing.Union[dataset.Dataset, typing.Awaitable[dataset.Dataset]], + ]: + raise NotImplementedError() + + @property + def update_dataset( + self, + ) -> typing.Callable[ + [dataset_service.UpdateDatasetRequest], + typing.Union[gca_dataset.Dataset, typing.Awaitable[gca_dataset.Dataset]], + ]: + raise NotImplementedError() + + @property + def list_datasets( + self, + ) -> typing.Callable[ + [dataset_service.ListDatasetsRequest], + typing.Union[ + dataset_service.ListDatasetsResponse, + typing.Awaitable[dataset_service.ListDatasetsResponse], + ], + ]: 
+ raise NotImplementedError() + + @property + def delete_dataset( + self, + ) -> typing.Callable[ + [dataset_service.DeleteDatasetRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def import_data( + self, + ) -> typing.Callable[ + [dataset_service.ImportDataRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def export_data( + self, + ) -> typing.Callable[ + [dataset_service.ExportDataRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def list_data_items( + self, + ) -> typing.Callable[ + [dataset_service.ListDataItemsRequest], + typing.Union[ + dataset_service.ListDataItemsResponse, + typing.Awaitable[dataset_service.ListDataItemsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_annotation_spec( + self, + ) -> typing.Callable[ + [dataset_service.GetAnnotationSpecRequest], + typing.Union[ + annotation_spec.AnnotationSpec, + typing.Awaitable[annotation_spec.AnnotationSpec], + ], + ]: + raise NotImplementedError() + + @property + def list_annotations( + self, + ) -> typing.Callable[ + [dataset_service.ListAnnotationsRequest], + typing.Union[ + dataset_service.ListAnnotationsResponse, + typing.Awaitable[dataset_service.ListAnnotationsResponse], + ], + ]: + raise NotImplementedError() + + +__all__ = ("DatasetServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py new file mode 100644 index 0000000000..e5a54388cb --- /dev/null +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc.py @@ -0,0 +1,539 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1.types import annotation_spec +from google.cloud.aiplatform_v1.types import dataset +from google.cloud.aiplatform_v1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1.types import dataset_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import DatasetServiceTransport, DEFAULT_CLIENT_INFO + + +class DatasetServiceGrpcTransport(DatasetServiceTransport): + """gRPC backend transport for DatasetService. 
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _stubs: Dict[str, Callable]
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Sequence[str] = None,
+        channel: grpc.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+            channel (Optional[grpc.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for the grpc channel. It is ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure a mutual TLS channel. It is
+                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
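+
+        Most applications never build this transport directly; passing
+        ``transport="grpc"`` to ``DatasetServiceClient`` constructs one with
+        the client's resolved endpoint and credential settings.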
+ """ + self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. 
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    @property
+    def grpc_channel(self) -> grpc.Channel:
+        """Return the channel designed to connect to this service.
+        """
+        return self._grpc_channel
+
+    @property
+    def operations_client(self) -> operations_v1.OperationsClient:
+        """Create the client designed to process long-running operations.
+
+        This property caches on the instance; repeated calls return the same
+        client.
+        """
+        # Sanity check: Only create a new client if we do not already have one.
+        if self._operations_client is None:
+            self._operations_client = operations_v1.OperationsClient(self.grpc_channel)
+
+        # Return the client from cache.
+        return self._operations_client
+
+    @property
+    def create_dataset(
+        self,
+    ) -> Callable[[dataset_service.CreateDatasetRequest], operations.Operation]:
+        r"""Return a callable for the create dataset method over gRPC.
+
+        Creates a Dataset.
+
+        Returns:
+            Callable[[~.CreateDatasetRequest],
+                    ~.Operation]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "create_dataset" not in self._stubs:
+            self._stubs["create_dataset"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.DatasetService/CreateDataset",
+                request_serializer=dataset_service.CreateDatasetRequest.serialize,
+                response_deserializer=operations.Operation.FromString,
+            )
+        return self._stubs["create_dataset"]
+
+    @property
+    def get_dataset(
+        self,
+    ) -> Callable[[dataset_service.GetDatasetRequest], dataset.Dataset]:
+        r"""Return a callable for the get dataset method over gRPC.
+
+        Gets a Dataset.
+
+        Returns:
+            Callable[[~.GetDatasetRequest],
+                    ~.Dataset]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "get_dataset" not in self._stubs:
+            self._stubs["get_dataset"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.DatasetService/GetDataset",
+                request_serializer=dataset_service.GetDatasetRequest.serialize,
+                response_deserializer=dataset.Dataset.deserialize,
+            )
+        return self._stubs["get_dataset"]
+
+    @property
+    def update_dataset(
+        self,
+    ) -> Callable[[dataset_service.UpdateDatasetRequest], gca_dataset.Dataset]:
+        r"""Return a callable for the update dataset method over gRPC.
+
+        Updates a Dataset.
+
+        Returns:
+            Callable[[~.UpdateDatasetRequest],
+                    ~.Dataset]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+ """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_dataset" not in self._stubs: + self._stubs["update_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/UpdateDataset", + request_serializer=dataset_service.UpdateDatasetRequest.serialize, + response_deserializer=gca_dataset.Dataset.deserialize, + ) + return self._stubs["update_dataset"] + + @property + def list_datasets( + self, + ) -> Callable[ + [dataset_service.ListDatasetsRequest], dataset_service.ListDatasetsResponse + ]: + r"""Return a callable for the list datasets method over gRPC. + + Lists Datasets in a Location. + + Returns: + Callable[[~.ListDatasetsRequest], + ~.ListDatasetsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_datasets" not in self._stubs: + self._stubs["list_datasets"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ListDatasets", + request_serializer=dataset_service.ListDatasetsRequest.serialize, + response_deserializer=dataset_service.ListDatasetsResponse.deserialize, + ) + return self._stubs["list_datasets"] + + @property + def delete_dataset( + self, + ) -> Callable[[dataset_service.DeleteDatasetRequest], operations.Operation]: + r"""Return a callable for the delete dataset method over gRPC. + + Deletes a Dataset. + + Returns: + Callable[[~.DeleteDatasetRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_dataset" not in self._stubs: + self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/DeleteDataset", + request_serializer=dataset_service.DeleteDatasetRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_dataset"] + + @property + def import_data( + self, + ) -> Callable[[dataset_service.ImportDataRequest], operations.Operation]: + r"""Return a callable for the import data method over gRPC. + + Imports data into a Dataset. + + Returns: + Callable[[~.ImportDataRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "import_data" not in self._stubs: + self._stubs["import_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ImportData", + request_serializer=dataset_service.ImportDataRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["import_data"] + + @property + def export_data( + self, + ) -> Callable[[dataset_service.ExportDataRequest], operations.Operation]: + r"""Return a callable for the export data method over gRPC. + + Exports data from a Dataset. 
+
+        Returns:
+            Callable[[~.ExportDataRequest],
+                    ~.Operation]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "export_data" not in self._stubs:
+            self._stubs["export_data"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.DatasetService/ExportData",
+                request_serializer=dataset_service.ExportDataRequest.serialize,
+                response_deserializer=operations.Operation.FromString,
+            )
+        return self._stubs["export_data"]
+
+    @property
+    def list_data_items(
+        self,
+    ) -> Callable[
+        [dataset_service.ListDataItemsRequest], dataset_service.ListDataItemsResponse
+    ]:
+        r"""Return a callable for the list data items method over gRPC.
+
+        Lists DataItems in a Dataset.
+
+        Returns:
+            Callable[[~.ListDataItemsRequest],
+                    ~.ListDataItemsResponse]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "list_data_items" not in self._stubs:
+            self._stubs["list_data_items"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.DatasetService/ListDataItems",
+                request_serializer=dataset_service.ListDataItemsRequest.serialize,
+                response_deserializer=dataset_service.ListDataItemsResponse.deserialize,
+            )
+        return self._stubs["list_data_items"]
+
+    @property
+    def get_annotation_spec(
+        self,
+    ) -> Callable[
+        [dataset_service.GetAnnotationSpecRequest], annotation_spec.AnnotationSpec
+    ]:
+        r"""Return a callable for the get annotation spec method over gRPC.
+
+        Gets an AnnotationSpec.
+
+        Returns:
+            Callable[[~.GetAnnotationSpecRequest],
+                    ~.AnnotationSpec]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "get_annotation_spec" not in self._stubs:
+            self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.DatasetService/GetAnnotationSpec",
+                request_serializer=dataset_service.GetAnnotationSpecRequest.serialize,
+                response_deserializer=annotation_spec.AnnotationSpec.deserialize,
+            )
+        return self._stubs["get_annotation_spec"]
+
+    @property
+    def list_annotations(
+        self,
+    ) -> Callable[
+        [dataset_service.ListAnnotationsRequest],
+        dataset_service.ListAnnotationsResponse,
+    ]:
+        r"""Return a callable for the list annotations method over gRPC.
+
+        Lists Annotations that belong to a DataItem.
+
+        Returns:
+            Callable[[~.ListAnnotationsRequest],
+                    ~.ListAnnotationsResponse]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+ if "list_annotations" not in self._stubs: + self._stubs["list_annotations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ListAnnotations", + request_serializer=dataset_service.ListAnnotationsRequest.serialize, + response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, + ) + return self._stubs["list_annotations"] + + +__all__ = ("DatasetServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..bcf3331d6b --- /dev/null +++ b/google/cloud/aiplatform_v1/services/dataset_service/transports/grpc_asyncio.py @@ -0,0 +1,554 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1.types import annotation_spec +from google.cloud.aiplatform_v1.types import dataset +from google.cloud.aiplatform_v1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1.types import dataset_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import DatasetServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import DatasetServiceGrpcTransport + + +class DatasetServiceGrpcAsyncIOTransport(DatasetServiceTransport): + """gRPC AsyncIO backend transport for DatasetService. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + address (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. 
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers_async.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        channel: aio.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id=None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            channel (Optional[aio.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure mutual TLS channel. It is
+                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        self._ssl_channel_credentials = ssl_channel_credentials
+
+        if api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+        if client_cert_source:
+            warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+            self._ssl_channel_credentials = None
+        elif api_mtls_endpoint:
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            self._ssl_channel_credentials = ssl_credentials
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            if client_cert_source_for_mtls and not ssl_channel_credentials:
+                cert, key = client_cert_source_for_mtls()
+                self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=self._ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
+
+        self._stubs = {}
+        self._operations_client = None
+
+    @property
+    def grpc_channel(self) -> aio.Channel:
+        """Create the channel designed to connect to this service.
+
+        This property caches on the instance; repeated calls return
+        the same channel.
+        """
+        # Return the channel from cache.
+        return self._grpc_channel
+
+    @property
+    def operations_client(self) -> operations_v1.OperationsAsyncClient:
+        """Create the client designed to process long-running operations.
+
+        This property caches on the instance; repeated calls return the same
+        client.
+        """
+        # Sanity check: Only create a new client if we do not already have one.
+ if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_dataset( + self, + ) -> Callable[ + [dataset_service.CreateDatasetRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the create dataset method over gRPC. + + Creates a Dataset. + + Returns: + Callable[[~.CreateDatasetRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_dataset" not in self._stubs: + self._stubs["create_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/CreateDataset", + request_serializer=dataset_service.CreateDatasetRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["create_dataset"] + + @property + def get_dataset( + self, + ) -> Callable[[dataset_service.GetDatasetRequest], Awaitable[dataset.Dataset]]: + r"""Return a callable for the get dataset method over gRPC. + + Gets a Dataset. + + Returns: + Callable[[~.GetDatasetRequest], + Awaitable[~.Dataset]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_dataset" not in self._stubs: + self._stubs["get_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/GetDataset", + request_serializer=dataset_service.GetDatasetRequest.serialize, + response_deserializer=dataset.Dataset.deserialize, + ) + return self._stubs["get_dataset"] + + @property + def update_dataset( + self, + ) -> Callable[ + [dataset_service.UpdateDatasetRequest], Awaitable[gca_dataset.Dataset] + ]: + r"""Return a callable for the update dataset method over gRPC. + + Updates a Dataset. + + Returns: + Callable[[~.UpdateDatasetRequest], + Awaitable[~.Dataset]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_dataset" not in self._stubs: + self._stubs["update_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/UpdateDataset", + request_serializer=dataset_service.UpdateDatasetRequest.serialize, + response_deserializer=gca_dataset.Dataset.deserialize, + ) + return self._stubs["update_dataset"] + + @property + def list_datasets( + self, + ) -> Callable[ + [dataset_service.ListDatasetsRequest], + Awaitable[dataset_service.ListDatasetsResponse], + ]: + r"""Return a callable for the list datasets method over gRPC. + + Lists Datasets in a Location. + + Returns: + Callable[[~.ListDatasetsRequest], + Awaitable[~.ListDatasetsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "list_datasets" not in self._stubs: + self._stubs["list_datasets"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ListDatasets", + request_serializer=dataset_service.ListDatasetsRequest.serialize, + response_deserializer=dataset_service.ListDatasetsResponse.deserialize, + ) + return self._stubs["list_datasets"] + + @property + def delete_dataset( + self, + ) -> Callable[ + [dataset_service.DeleteDatasetRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the delete dataset method over gRPC. + + Deletes a Dataset. + + Returns: + Callable[[~.DeleteDatasetRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_dataset" not in self._stubs: + self._stubs["delete_dataset"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/DeleteDataset", + request_serializer=dataset_service.DeleteDatasetRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_dataset"] + + @property + def import_data( + self, + ) -> Callable[[dataset_service.ImportDataRequest], Awaitable[operations.Operation]]: + r"""Return a callable for the import data method over gRPC. + + Imports data into a Dataset. + + Returns: + Callable[[~.ImportDataRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "import_data" not in self._stubs: + self._stubs["import_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ImportData", + request_serializer=dataset_service.ImportDataRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["import_data"] + + @property + def export_data( + self, + ) -> Callable[[dataset_service.ExportDataRequest], Awaitable[operations.Operation]]: + r"""Return a callable for the export data method over gRPC. + + Exports data from a Dataset. + + Returns: + Callable[[~.ExportDataRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "export_data" not in self._stubs: + self._stubs["export_data"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ExportData", + request_serializer=dataset_service.ExportDataRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["export_data"] + + @property + def list_data_items( + self, + ) -> Callable[ + [dataset_service.ListDataItemsRequest], + Awaitable[dataset_service.ListDataItemsResponse], + ]: + r"""Return a callable for the list data items method over gRPC. + + Lists DataItems in a Dataset. + + Returns: + Callable[[~.ListDataItemsRequest], + Awaitable[~.ListDataItemsResponse]]: + A function that, when called, will call the underlying RPC + on the server. 
+ """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_data_items" not in self._stubs: + self._stubs["list_data_items"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ListDataItems", + request_serializer=dataset_service.ListDataItemsRequest.serialize, + response_deserializer=dataset_service.ListDataItemsResponse.deserialize, + ) + return self._stubs["list_data_items"] + + @property + def get_annotation_spec( + self, + ) -> Callable[ + [dataset_service.GetAnnotationSpecRequest], + Awaitable[annotation_spec.AnnotationSpec], + ]: + r"""Return a callable for the get annotation spec method over gRPC. + + Gets an AnnotationSpec. + + Returns: + Callable[[~.GetAnnotationSpecRequest], + Awaitable[~.AnnotationSpec]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_annotation_spec" not in self._stubs: + self._stubs["get_annotation_spec"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/GetAnnotationSpec", + request_serializer=dataset_service.GetAnnotationSpecRequest.serialize, + response_deserializer=annotation_spec.AnnotationSpec.deserialize, + ) + return self._stubs["get_annotation_spec"] + + @property + def list_annotations( + self, + ) -> Callable[ + [dataset_service.ListAnnotationsRequest], + Awaitable[dataset_service.ListAnnotationsResponse], + ]: + r"""Return a callable for the list annotations method over gRPC. + + Lists Annotations belongs to a dataitem + + Returns: + Callable[[~.ListAnnotationsRequest], + Awaitable[~.ListAnnotationsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_annotations" not in self._stubs: + self._stubs["list_annotations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.DatasetService/ListAnnotations", + request_serializer=dataset_service.ListAnnotationsRequest.serialize, + response_deserializer=dataset_service.ListAnnotationsResponse.deserialize, + ) + return self._stubs["list_annotations"] + + +__all__ = ("DatasetServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py b/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py new file mode 100644 index 0000000000..035a5b2388 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/endpoint_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from .client import EndpointServiceClient +from .async_client import EndpointServiceAsyncClient + +__all__ = ( + "EndpointServiceClient", + "EndpointServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py new file mode 100644 index 0000000000..e36aa6dfde --- /dev/null +++ b/google/cloud/aiplatform_v1/services/endpoint_service/async_client.py @@ -0,0 +1,841 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.endpoint_service import pagers +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import endpoint +from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint +from google.cloud.aiplatform_v1.types import endpoint_service +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import EndpointServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import EndpointServiceGrpcAsyncIOTransport +from .client import EndpointServiceClient + + +class EndpointServiceAsyncClient: + """""" + + _client: EndpointServiceClient + + DEFAULT_ENDPOINT = EndpointServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = EndpointServiceClient.DEFAULT_MTLS_ENDPOINT + + endpoint_path = staticmethod(EndpointServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(EndpointServiceClient.parse_endpoint_path) + model_path = staticmethod(EndpointServiceClient.model_path) + parse_model_path = staticmethod(EndpointServiceClient.parse_model_path) + + common_billing_account_path = staticmethod( + EndpointServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + EndpointServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(EndpointServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + EndpointServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + EndpointServiceClient.common_organization_path + ) + 
parse_common_organization_path = staticmethod( + EndpointServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(EndpointServiceClient.common_project_path) + parse_common_project_path = staticmethod( + EndpointServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod(EndpointServiceClient.common_location_path) + parse_common_location_path = staticmethod( + EndpointServiceClient.parse_common_location_path + ) + + from_service_account_info = EndpointServiceClient.from_service_account_info + from_service_account_file = EndpointServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> EndpointServiceTransport: + """Return the transport used by the client instance. + + Returns: + EndpointServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(EndpointServiceClient).get_transport_class, type(EndpointServiceClient) + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, EndpointServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the endpoint service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.EndpointServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = EndpointServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def create_endpoint( + self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates an Endpoint. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.CreateEndpointRequest`): + The request object. 
Request message for + ``EndpointService.CreateEndpoint``. + parent (:class:`str`): + Required. The resource name of the Location to create + the Endpoint in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + endpoint (:class:`google.cloud.aiplatform_v1.types.Endpoint`): + Required. The Endpoint to create. + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1.types.Endpoint` Models are deployed into it, and afterwards Endpoint is called to obtain + predictions and explanations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, endpoint]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = endpoint_service.CreateEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if endpoint is not None: + request.endpoint = endpoint + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_endpoint.Endpoint, + metadata_type=endpoint_service.CreateEndpointOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_endpoint( + self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: + r"""Gets an Endpoint. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.GetEndpointRequest`): + The request object. Request message for + ``EndpointService.GetEndpoint`` + name (:class:`str`): + Required. The name of the Endpoint resource. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Endpoint: + Models are deployed into it, and + afterwards Endpoint is called to obtain + predictions and explanations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = endpoint_service.GetEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_endpoints( + self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsAsyncPager: + r"""Lists Endpoints in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListEndpointsRequest`): + The request object. Request message for + ``EndpointService.ListEndpoints``. + parent (:class:`str`): + Required. The resource name of the Location from which + to list the Endpoints. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.endpoint_service.pagers.ListEndpointsAsyncPager: + Response message for + ``EndpointService.ListEndpoints``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = endpoint_service.ListEndpointsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
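Editor's note: ``create_endpoint``, ``get_endpoint``, and ``list_endpoints`` above all enforce the same convention: pass either a ``request`` object or the flattened fields, never both. A hedged sketch of the two call styles (the resource name is a placeholder; this snippet is not part of the patch):

```python
# Hedged sketch of the two calling conventions enforced above.
# Passing `request` together with flattened fields raises ValueError.
from google.cloud.aiplatform_v1.types import endpoint_service


async def get_endpoint_both_ways(client):
    name = "projects/my-project/locations/us-central1/endpoints/123"

    # Style 1: an explicit request object.
    request = endpoint_service.GetEndpointRequest(name=name)
    ep1 = await client.get_endpoint(request=request)

    # Style 2: flattened keyword arguments.
    ep2 = await client.get_endpoint(name=name)
    return ep1, ep2
```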
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.list_endpoints,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+        )
+
+        # Send the request.
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__aiter__` convenience method.
+        response = pagers.ListEndpointsAsyncPager(
+            method=rpc, request=request, response=response, metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def update_endpoint(
+        self,
+        request: endpoint_service.UpdateEndpointRequest = None,
+        *,
+        endpoint: gca_endpoint.Endpoint = None,
+        update_mask: field_mask.FieldMask = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> gca_endpoint.Endpoint:
+        r"""Updates an Endpoint.
+
+        Args:
+            request (:class:`google.cloud.aiplatform_v1.types.UpdateEndpointRequest`):
+                The request object. Request message for
+                ``EndpointService.UpdateEndpoint``.
+            endpoint (:class:`google.cloud.aiplatform_v1.types.Endpoint`):
+                Required. The Endpoint which replaces
+                the resource on the server.
+
+                This corresponds to the ``endpoint`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`):
+                Required. The update mask applies to the resource. See
+                `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__.
+
+                This corresponds to the ``update_mask`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.cloud.aiplatform_v1.types.Endpoint:
+                Models are deployed into it, and
+                afterwards Endpoint is called to obtain
+                predictions and explanations.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([endpoint, update_mask])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        request = endpoint_service.UpdateEndpointRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if endpoint is not None:
+            request.endpoint = endpoint
+        if update_mask is not None:
+            request.update_mask = update_mask
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.update_endpoint,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata(
+                (("endpoint.name", request.endpoint.name),)
+            ),
+        )
+
+        # Send the request.
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # Done; return the response.
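Editor's note: the ``update_mask`` taken by ``update_endpoint`` above is a standard protobuf ``FieldMask``; only the listed paths are overwritten on the server. A hedged sketch (the new display name is a placeholder; not part of the patch):

```python
# Hedged sketch: updating only the display name of an Endpoint via a
# FieldMask, as described above. Only listed paths are overwritten.
from google.protobuf import field_mask_pb2


async def rename_endpoint(client, endpoint):
    endpoint.display_name = "renamed-endpoint"  # placeholder name
    mask = field_mask_pb2.FieldMask(paths=["display_name"])
    return await client.update_endpoint(endpoint=endpoint, update_mask=mask)
```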
+ return response + + async def delete_endpoint( + self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes an Endpoint. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.DeleteEndpointRequest`): + The request object. Request message for + ``EndpointService.DeleteEndpoint``. + name (:class:`str`): + Required. The name of the Endpoint resource to be + deleted. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = endpoint_service.DeleteEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_endpoint, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def deploy_model( + self, + request: endpoint_service.DeployModelRequest = None, + *, + endpoint: str = None, + deployed_model: gca_endpoint.DeployedModel = None, + traffic_split: Sequence[ + endpoint_service.DeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deploys a Model into this Endpoint, creating a + DeployedModel within it. 
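Editor's note: ``create_endpoint`` and ``delete_endpoint`` above wrap their responses in ``operation_async.AsyncOperation`` futures for the server-side long-running operations; ``result()`` on that future is itself a coroutine and must be awaited. A hedged sketch (the endpoint name is a placeholder):

```python
# Hedged sketch: awaiting the long-running operation returned by
# delete_endpoint. AsyncOperation.result() is a coroutine, so it is
# awaited as well. `name` is a placeholder resource path.
async def delete_endpoint_and_wait(client, name: str) -> None:
    lro = await client.delete_endpoint(name=name)
    await lro.result()  # resolves once the server-side LRO finishes
```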
+ + Args: + request (:class:`google.cloud.aiplatform_v1.types.DeployModelRequest`): + The request object. Request message for + ``EndpointService.DeployModel``. + endpoint (:class:`str`): + Required. The name of the Endpoint resource into which + to deploy a Model. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + deployed_model (:class:`google.cloud.aiplatform_v1.types.DeployedModel`): + Required. The DeployedModel to be created within the + Endpoint. Note that + ``Endpoint.traffic_split`` + must be updated for the DeployedModel to start receiving + traffic, either as part of this call, or via + ``EndpointService.UpdateEndpoint``. + + This corresponds to the ``deployed_model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + traffic_split (:class:`Sequence[google.cloud.aiplatform_v1.types.DeployModelRequest.TrafficSplitEntry]`): + A map from a DeployedModel's ID to the percentage of + this Endpoint's traffic that should be forwarded to that + DeployedModel. + + If this field is non-empty, then the Endpoint's + ``traffic_split`` + will be overwritten with it. To refer to the ID of the + just being deployed Model, a "0" should be used, and the + actual ID of the new DeployedModel will be filled in its + place by this method. The traffic percentage values must + add up to 100. + + If this field is empty, then the Endpoint's + ``traffic_split`` + is not updated. + + This corresponds to the ``traffic_split`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.DeployModelResponse` + Response message for + ``EndpointService.DeployModel``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([endpoint, deployed_model, traffic_split]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = endpoint_service.DeployModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if deployed_model is not None: + request.deployed_model = deployed_model + + if traffic_split: + request.traffic_split.update(traffic_split) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.deploy_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. 
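Editor's note: per the ``traffic_split`` semantics documented above ("0" stands for the model being deployed, and the values must sum to 100), a hedged deployment sketch follows. The resource names, machine type, and replica settings are all placeholder assumptions, not part of the patch:

```python
# Hedged sketch: deploying a model and routing all traffic to it.
# "0" in traffic_split refers to the DeployedModel created by this
# call; every resource name below is a placeholder.
from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint
from google.cloud.aiplatform_v1.types import machine_resources


async def deploy_all_traffic(client, endpoint_name: str, model_name: str):
    deployed_model = gca_endpoint.DeployedModel(
        model=model_name,
        display_name="my-deployment",
        dedicated_resources=machine_resources.DedicatedResources(
            machine_spec=machine_resources.MachineSpec(
                machine_type="n1-standard-2"
            ),
            min_replica_count=1,
        ),
    )
    lro = await client.deploy_model(
        endpoint=endpoint_name,
        deployed_model=deployed_model,
        traffic_split={"0": 100},
    )
    return await lro.result()
```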
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # Wrap the response in an operation future.
+        response = operation_async.from_gapic(
+            response,
+            self._client._transport.operations_client,
+            endpoint_service.DeployModelResponse,
+            metadata_type=endpoint_service.DeployModelOperationMetadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def undeploy_model(
+        self,
+        request: endpoint_service.UndeployModelRequest = None,
+        *,
+        endpoint: str = None,
+        deployed_model_id: str = None,
+        traffic_split: Sequence[
+            endpoint_service.UndeployModelRequest.TrafficSplitEntry
+        ] = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operation_async.AsyncOperation:
+        r"""Undeploys a Model from an Endpoint, removing a
+        DeployedModel from it, and freeing all resources it's
+        using.
+
+        Args:
+            request (:class:`google.cloud.aiplatform_v1.types.UndeployModelRequest`):
+                The request object. Request message for
+                ``EndpointService.UndeployModel``.
+            endpoint (:class:`str`):
+                Required. The name of the Endpoint resource from which
+                to undeploy a Model. Format:
+                ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+
+                This corresponds to the ``endpoint`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            deployed_model_id (:class:`str`):
+                Required. The ID of the DeployedModel
+                to be undeployed from the Endpoint.
+
+                This corresponds to the ``deployed_model_id`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            traffic_split (:class:`Sequence[google.cloud.aiplatform_v1.types.UndeployModelRequest.TrafficSplitEntry]`):
+                If this field is provided, then the Endpoint's
+                ``traffic_split``
+                will be overwritten with it. If the last DeployedModel is
+                being undeployed from the Endpoint, the
+                ``Endpoint.traffic_split`` will always end up empty when
+                this call returns. A DeployedModel will be successfully
+                undeployed only if it doesn't have any traffic assigned
+                to it when this method executes, or if this field
+                unassigns any traffic from it.
+
+                This corresponds to the ``traffic_split`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.api_core.operation_async.AsyncOperation:
+                An object representing a long-running operation.
+
+                The result type for the operation will be
+                :class:`google.cloud.aiplatform_v1.types.UndeployModelResponse`
+                Response message for
+                ``EndpointService.UndeployModel``.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([endpoint, deployed_model_id, traffic_split])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        request = endpoint_service.UndeployModelRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
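As the sanity check above enforces, a prebuilt request object and the flattened field arguments are mutually exclusive. A short sketch of both call shapes (resource names are placeholders):

from google.cloud import aiplatform_v1


async def undeploy(client: aiplatform_v1.EndpointServiceAsyncClient) -> None:
    request = aiplatform_v1.UndeployModelRequest(
        endpoint="projects/my-project/locations/us-central1/endpoints/123",
        deployed_model_id="456",
    )
    # Either pass the request object alone...
    lro = await client.undeploy_model(request=request)
    await lro.result()

    # ...or pass flattened fields alone; mixing the two raises ValueError.
    lro = await client.undeploy_model(
        endpoint="projects/my-project/locations/us-central1/endpoints/123",
        deployed_model_id="456",
    )
    await lro.result()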
+ + if endpoint is not None: + request.endpoint = endpoint + if deployed_model_id is not None: + request.deployed_model_id = deployed_model_id + + if traffic_split: + request.traffic_split.update(traffic_split) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.undeploy_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + endpoint_service.UndeployModelResponse, + metadata_type=endpoint_service.UndeployModelOperationMetadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("EndpointServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1/services/endpoint_service/client.py new file mode 100644 index 0000000000..1316effa58 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/endpoint_service/client.py @@ -0,0 +1,1064 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from collections import OrderedDict
+from distutils import util
+import os
+import re
+from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union
+import pkg_resources
+
+from google.api_core import client_options as client_options_lib  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport import mtls  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+from google.auth.exceptions import MutualTLSChannelError  # type: ignore
+from google.oauth2 import service_account  # type: ignore
+
+from google.api_core import operation as ga_operation  # type: ignore
+from google.api_core import operation_async  # type: ignore
+from google.cloud.aiplatform_v1.services.endpoint_service import pagers
+from google.cloud.aiplatform_v1.types import encryption_spec
+from google.cloud.aiplatform_v1.types import endpoint
+from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint
+from google.cloud.aiplatform_v1.types import endpoint_service
+from google.cloud.aiplatform_v1.types import operation as gca_operation
+from google.protobuf import empty_pb2 as empty  # type: ignore
+from google.protobuf import field_mask_pb2 as field_mask  # type: ignore
+from google.protobuf import timestamp_pb2 as timestamp  # type: ignore
+
+from .transports.base import EndpointServiceTransport, DEFAULT_CLIENT_INFO
+from .transports.grpc import EndpointServiceGrpcTransport
+from .transports.grpc_asyncio import EndpointServiceGrpcAsyncIOTransport
+
+
+class EndpointServiceClientMeta(type):
+    """Metaclass for the EndpointService client.
+
+    This provides class-level methods for building and retrieving
+    support objects (e.g. transport) without polluting the client instance
+    objects.
+    """
+
+    _transport_registry = (
+        OrderedDict()
+    )  # type: Dict[str, Type[EndpointServiceTransport]]
+    _transport_registry["grpc"] = EndpointServiceGrpcTransport
+    _transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport
+
+    def get_transport_class(cls, label: str = None,) -> Type[EndpointServiceTransport]:
+        """Return an appropriate transport class.
+
+        Args:
+            label: The name of the desired transport. If none is
+                provided, then the first transport in the registry is used.
+
+        Returns:
+            The transport class to use.
+        """
+        # If a specific transport is requested, return that one.
+        if label:
+            return cls._transport_registry[label]
+
+        # No transport is requested; return the default (that is, the first one
+        # in the dictionary).
+        return next(iter(cls._transport_registry.values()))
+
+
+class EndpointServiceClient(metaclass=EndpointServiceClientMeta):
+    """A service for managing AI Platform's Endpoints."""
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Convert api endpoint to mTLS endpoint.
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+ ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + EndpointServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + EndpointServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> EndpointServiceTransport: + """Return the transport used by the client instance. + + Returns: + EndpointServiceTransport: The transport used by the client instance. 
+        """
+        return self._transport
+
+    @staticmethod
+    def endpoint_path(project: str, location: str, endpoint: str,) -> str:
+        """Return a fully-qualified endpoint string."""
+        return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
+            project=project, location=location, endpoint=endpoint,
+        )
+
+    @staticmethod
+    def parse_endpoint_path(path: str) -> Dict[str, str]:
+        """Parse an endpoint path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def model_path(project: str, location: str, model: str,) -> str:
+        """Return a fully-qualified model string."""
+        return "projects/{project}/locations/{location}/models/{model}".format(
+            project=project, location=location, model=model,
+        )
+
+    @staticmethod
+    def parse_model_path(path: str) -> Dict[str, str]:
+        """Parse a model path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_billing_account_path(billing_account: str,) -> str:
+        """Return a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(
+            billing_account=billing_account,
+        )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str, str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(folder: str,) -> str:
+        """Return a fully-qualified folder string."""
+        return "folders/{folder}".format(folder=folder,)
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str, str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_organization_path(organization: str,) -> str:
+        """Return a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization,)
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str, str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str,) -> str:
+        """Return a fully-qualified project string."""
+        return "projects/{project}".format(project=project,)
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str, str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str,) -> str:
+        """Return a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(
+            project=project, location=location,
+        )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str, str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    def __init__(
+        self,
+        *,
+        credentials: Optional[credentials.Credentials] = None,
+        transport: Union[str, EndpointServiceTransport, None] = None,
+        client_options: Optional[client_options_lib.ClientOptions] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate
the endpoint service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, EndpointServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, EndpointServiceTransport): + # transport is a EndpointServiceTransport instance. 
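A short sketch of the endpoint selection logic above (the regional endpoint shown is an assumed example): the environment variable steers mTLS behavior, but an explicit api_endpoint always wins.

import os

from google.api_core.client_options import ClientOptions
from google.cloud import aiplatform_v1

# Never switch to the mTLS endpoint, even if a client certificate exists.
os.environ["GOOGLE_API_USE_MTLS_ENDPOINT"] = "never"

# An explicit api_endpoint takes precedence over the environment variable.
client = aiplatform_v1.EndpointServiceClient(
    client_options=ClientOptions(
        api_endpoint="us-central1-aiplatform.googleapis.com"
    )
)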
+ if credentials or client_options.credentials_file: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def create_endpoint( + self, + request: endpoint_service.CreateEndpointRequest = None, + *, + parent: str = None, + endpoint: gca_endpoint.Endpoint = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Creates an Endpoint. + + Args: + request (google.cloud.aiplatform_v1.types.CreateEndpointRequest): + The request object. Request message for + ``EndpointService.CreateEndpoint``. + parent (str): + Required. The resource name of the Location to create + the Endpoint in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + endpoint (google.cloud.aiplatform_v1.types.Endpoint): + Required. The Endpoint to create. + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1.types.Endpoint` Models are deployed into it, and afterwards Endpoint is called to obtain + predictions and explanations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, endpoint]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.CreateEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.CreateEndpointRequest): + request = endpoint_service.CreateEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if endpoint is not None: + request.endpoint = endpoint + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. 
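A minimal synchronous sketch of the method being assembled here (project and location are placeholders): create_endpoint returns a google.api_core.operation.Operation whose result() is the created Endpoint.

from google.cloud import aiplatform_v1

client = aiplatform_v1.EndpointServiceClient()
lro = client.create_endpoint(
    parent="projects/my-project/locations/us-central1",
    endpoint=aiplatform_v1.Endpoint(display_name="my-endpoint"),
)
created = lro.result()  # blocks until the operation completes
print(created.name)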
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + gca_endpoint.Endpoint, + metadata_type=endpoint_service.CreateEndpointOperationMetadata, + ) + + # Done; return the response. + return response + + def get_endpoint( + self, + request: endpoint_service.GetEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> endpoint.Endpoint: + r"""Gets an Endpoint. + + Args: + request (google.cloud.aiplatform_v1.types.GetEndpointRequest): + The request object. Request message for + ``EndpointService.GetEndpoint`` + name (str): + Required. The name of the Endpoint resource. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Endpoint: + Models are deployed into it, and + afterwards Endpoint is called to obtain + predictions and explanations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.GetEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.GetEndpointRequest): + request = endpoint_service.GetEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_endpoints( + self, + request: endpoint_service.ListEndpointsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListEndpointsPager: + r"""Lists Endpoints in a Location. + + Args: + request (google.cloud.aiplatform_v1.types.ListEndpointsRequest): + The request object. Request message for + ``EndpointService.ListEndpoints``. + parent (str): + Required. 
The resource name of the Location from which
+                to list the Endpoints. Format:
+                ``projects/{project}/locations/{location}``
+
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.cloud.aiplatform_v1.services.endpoint_service.pagers.ListEndpointsPager:
+                Response message for
+                ``EndpointService.ListEndpoints``.
+
+                Iterating over this object will yield results and
+                resolve additional pages automatically.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([parent])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        # Minor optimization to avoid making a copy if the user passes
+        # in a endpoint_service.ListEndpointsRequest.
+        # There's no risk of modifying the input as we've already verified
+        # there are no flattened fields.
+        if not isinstance(request, endpoint_service.ListEndpointsRequest):
+            request = endpoint_service.ListEndpointsRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if parent is not None:
+            request.parent = parent
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.list_endpoints]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+        )
+
+        # Send the request.
+        response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__iter__` convenience method.
+        response = pagers.ListEndpointsPager(
+            method=rpc, request=request, response=response, metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    def update_endpoint(
+        self,
+        request: endpoint_service.UpdateEndpointRequest = None,
+        *,
+        endpoint: gca_endpoint.Endpoint = None,
+        update_mask: field_mask.FieldMask = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> gca_endpoint.Endpoint:
+        r"""Updates an Endpoint.
+
+        Args:
+            request (google.cloud.aiplatform_v1.types.UpdateEndpointRequest):
+                The request object. Request message for
+                ``EndpointService.UpdateEndpoint``.
+            endpoint (google.cloud.aiplatform_v1.types.Endpoint):
+                Required. The Endpoint which replaces
+                the resource on the server.
+
+                This corresponds to the ``endpoint`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            update_mask (google.protobuf.field_mask_pb2.FieldMask):
+                Required. The update mask applies to the resource. See
+                `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__.
+
+                This corresponds to the ``update_mask`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+ timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Endpoint: + Models are deployed into it, and + afterwards Endpoint is called to obtain + predictions and explanations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([endpoint, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.UpdateEndpointRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.UpdateEndpointRequest): + request = endpoint_service.UpdateEndpointRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_endpoint] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("endpoint.name", request.endpoint.name),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def delete_endpoint( + self, + request: endpoint_service.DeleteEndpointRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Deletes an Endpoint. + + Args: + request (google.cloud.aiplatform_v1.types.DeleteEndpointRequest): + The request object. Request message for + ``EndpointService.DeleteEndpoint``. + name (str): + Required. The name of the Endpoint resource to be + deleted. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. 
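A hedged sketch of update_endpoint with the FieldMask described above (resource names are placeholders): only the paths listed in the mask are written.

from google.protobuf import field_mask_pb2

from google.cloud import aiplatform_v1

client = aiplatform_v1.EndpointServiceClient()
updated = client.update_endpoint(
    endpoint=aiplatform_v1.Endpoint(
        name="projects/my-project/locations/us-central1/endpoints/123",
        display_name="renamed-endpoint",
    ),
    # Only display_name is overwritten; all other fields keep their values.
    update_mask=field_mask_pb2.FieldMask(paths=["display_name"]),
)
print(updated.display_name)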
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([name])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        # Minor optimization to avoid making a copy if the user passes
+        # in a endpoint_service.DeleteEndpointRequest.
+        # There's no risk of modifying the input as we've already verified
+        # there are no flattened fields.
+        if not isinstance(request, endpoint_service.DeleteEndpointRequest):
+            request = endpoint_service.DeleteEndpointRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if name is not None:
+            request.name = name
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.delete_endpoint]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+        )
+
+        # Send the request.
+        response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # Wrap the response in an operation future.
+        response = ga_operation.from_gapic(
+            response,
+            self._transport.operations_client,
+            empty.Empty,
+            metadata_type=gca_operation.DeleteOperationMetadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    def deploy_model(
+        self,
+        request: endpoint_service.DeployModelRequest = None,
+        *,
+        endpoint: str = None,
+        deployed_model: gca_endpoint.DeployedModel = None,
+        traffic_split: Sequence[
+            endpoint_service.DeployModelRequest.TrafficSplitEntry
+        ] = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> ga_operation.Operation:
+        r"""Deploys a Model into this Endpoint, creating a
+        DeployedModel within it.
+
+        Args:
+            request (google.cloud.aiplatform_v1.types.DeployModelRequest):
+                The request object. Request message for
+                ``EndpointService.DeployModel``.
+            endpoint (str):
+                Required. The name of the Endpoint resource into which
+                to deploy a Model. Format:
+                ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+
+                This corresponds to the ``endpoint`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            deployed_model (google.cloud.aiplatform_v1.types.DeployedModel):
+                Required. The DeployedModel to be created within the
+                Endpoint. Note that
+                ``Endpoint.traffic_split``
+                must be updated for the DeployedModel to start receiving
+                traffic, either as part of this call, or via
+                ``EndpointService.UpdateEndpoint``.
+
+                This corresponds to the ``deployed_model`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            traffic_split (Sequence[google.cloud.aiplatform_v1.types.DeployModelRequest.TrafficSplitEntry]):
+                A map from a DeployedModel's ID to the percentage of
+                this Endpoint's traffic that should be forwarded to that
+                DeployedModel.
+
+                If this field is non-empty, then the Endpoint's
+                ``traffic_split``
+                will be overwritten with it. To refer to the ID of the
+                Model being deployed by this call, use "0"; the
+                actual ID of the new DeployedModel will be filled in its
+                place by this method. The traffic percentage values must
+                add up to 100.
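The sum-to-100 rule is easy to check locally before calling deploy_model; a tiny illustrative snippet (the existing DeployedModel ID "1234" is assumed):

# Keep 80% of traffic on an existing DeployedModel and send 20% to the
# model being deployed ("0" is the placeholder for its future ID).
traffic_split = {"1234": 80, "0": 20}
assert sum(traffic_split.values()) == 100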
+ + If this field is empty, then the Endpoint's + ``traffic_split`` + is not updated. + + This corresponds to the ``traffic_split`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.DeployModelResponse` + Response message for + ``EndpointService.DeployModel``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([endpoint, deployed_model, traffic_split]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a endpoint_service.DeployModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, endpoint_service.DeployModelRequest): + request = endpoint_service.DeployModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if deployed_model is not None: + request.deployed_model = deployed_model + + if traffic_split: + request.traffic_split.update(traffic_split) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.deploy_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + endpoint_service.DeployModelResponse, + metadata_type=endpoint_service.DeployModelOperationMetadata, + ) + + # Done; return the response. + return response + + def undeploy_model( + self, + request: endpoint_service.UndeployModelRequest = None, + *, + endpoint: str = None, + deployed_model_id: str = None, + traffic_split: Sequence[ + endpoint_service.UndeployModelRequest.TrafficSplitEntry + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Undeploys a Model from an Endpoint, removing a + DeployedModel from it, and freeing all resources it's + using. + + Args: + request (google.cloud.aiplatform_v1.types.UndeployModelRequest): + The request object. Request message for + ``EndpointService.UndeployModel``. + endpoint (str): + Required. The name of the Endpoint resource from which + to undeploy a Model. 
Format:
+                ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+
+                This corresponds to the ``endpoint`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            deployed_model_id (str):
+                Required. The ID of the DeployedModel
+                to be undeployed from the Endpoint.
+
+                This corresponds to the ``deployed_model_id`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            traffic_split (Sequence[google.cloud.aiplatform_v1.types.UndeployModelRequest.TrafficSplitEntry]):
+                If this field is provided, then the Endpoint's
+                ``traffic_split``
+                will be overwritten with it. If the last DeployedModel is
+                being undeployed from the Endpoint, the
+                ``Endpoint.traffic_split`` will always end up empty when
+                this call returns. A DeployedModel will be successfully
+                undeployed only if it doesn't have any traffic assigned
+                to it when this method executes, or if this field
+                unassigns any traffic from it.
+
+                This corresponds to the ``traffic_split`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.api_core.operation.Operation:
+                An object representing a long-running operation.
+
+                The result type for the operation will be
+                :class:`google.cloud.aiplatform_v1.types.UndeployModelResponse`
+                Response message for
+                ``EndpointService.UndeployModel``.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([endpoint, deployed_model_id, traffic_split])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        # Minor optimization to avoid making a copy if the user passes
+        # in a endpoint_service.UndeployModelRequest.
+        # There's no risk of modifying the input as we've already verified
+        # there are no flattened fields.
+        if not isinstance(request, endpoint_service.UndeployModelRequest):
+            request = endpoint_service.UndeployModelRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if endpoint is not None:
+            request.endpoint = endpoint
+        if deployed_model_id is not None:
+            request.deployed_model_id = deployed_model_id
+
+        if traffic_split:
+            request.traffic_split.update(traffic_split)
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.undeploy_model]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)),
+        )
+
+        # Send the request.
+        response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # Wrap the response in an operation future.
+        response = ga_operation.from_gapic(
+            response,
+            self._transport.operations_client,
+            endpoint_service.UndeployModelResponse,
+            metadata_type=endpoint_service.UndeployModelOperationMetadata,
+        )
+
+        # Done; return the response.
+ return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("EndpointServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py new file mode 100644 index 0000000000..01ebccdec3 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/endpoint_service/pagers.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple + +from google.cloud.aiplatform_v1.types import endpoint +from google.cloud.aiplatform_v1.types import endpoint_service + + +class ListEndpointsPager: + """A pager for iterating through ``list_endpoints`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListEndpointsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``endpoints`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListEndpoints`` requests and continue to iterate + through the ``endpoints`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListEndpointsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., endpoint_service.ListEndpointsResponse], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListEndpointsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListEndpointsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = endpoint_service.ListEndpointsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[endpoint_service.ListEndpointsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[endpoint.Endpoint]: + for page in self.pages: + yield from page.endpoints + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListEndpointsAsyncPager: + """A pager for iterating through ``list_endpoints`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListEndpointsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``endpoints`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListEndpoints`` requests and continue to iterate + through the ``endpoints`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListEndpointsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[endpoint_service.ListEndpointsResponse]], + request: endpoint_service.ListEndpointsRequest, + response: endpoint_service.ListEndpointsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListEndpointsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListEndpointsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = endpoint_service.ListEndpointsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[endpoint_service.ListEndpointsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[endpoint.Endpoint]: + async def async_generator(): + async for page in self.pages: + for response in page.endpoints: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py new file mode 100644 index 0000000000..3d0695461d --- /dev/null +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import EndpointServiceTransport +from .grpc import EndpointServiceGrpcTransport +from .grpc_asyncio import EndpointServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[EndpointServiceTransport]] +_transport_registry["grpc"] = EndpointServiceGrpcTransport +_transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport + +__all__ = ( + "EndpointServiceTransport", + "EndpointServiceGrpcTransport", + "EndpointServiceGrpcAsyncIOTransport", +) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py new file mode 100644 index 0000000000..728c38fec3 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/base.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import abc
+import typing
+import pkg_resources
+
+from google import auth  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google.auth import credentials  # type: ignore
+
+from google.cloud.aiplatform_v1.types import endpoint
+from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint
+from google.cloud.aiplatform_v1.types import endpoint_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            "google-cloud-aiplatform",
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+class EndpointServiceTransport(abc.ABC):
+    """Abstract transport class for EndpointService."""
+
+    AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: typing.Optional[str] = None,
+        scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+        quota_project_id: typing.Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        **kwargs,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+        """
+        # Save the hostname. Default to port 443 (HTTPS) if none is specified.
+        if ":" not in host:
+            host += ":443"
+        self._host = host
+
+        # If no credentials are provided, then determine the appropriate
+        # defaults.
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs(
+                "'credentials_file' and 'credentials' are mutually exclusive"
+            )
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file, scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(
+                scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        # Save the credentials.
+        self._credentials = credentials
+
+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
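Because every RPC in the wrapped-methods table that follows is installed with default_timeout=None, deadlines and retry policies come from the per-call arguments; a hedged sketch (resource name assumed):

from google.api_core import exceptions
from google.api_core import retry as retries

from google.cloud import aiplatform_v1

client = aiplatform_v1.EndpointServiceClient()
endpoint = client.get_endpoint(
    name="projects/my-project/locations/us-central1/endpoints/123",
    timeout=30.0,  # per-call deadline; no default is configured
    retry=retries.Retry(
        predicate=retries.if_exception_type(exceptions.ServiceUnavailable)
    ),
)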
+ self._wrapped_methods = { + self.create_endpoint: gapic_v1.method.wrap_method( + self.create_endpoint, default_timeout=None, client_info=client_info, + ), + self.get_endpoint: gapic_v1.method.wrap_method( + self.get_endpoint, default_timeout=None, client_info=client_info, + ), + self.list_endpoints: gapic_v1.method.wrap_method( + self.list_endpoints, default_timeout=None, client_info=client_info, + ), + self.update_endpoint: gapic_v1.method.wrap_method( + self.update_endpoint, default_timeout=None, client_info=client_info, + ), + self.delete_endpoint: gapic_v1.method.wrap_method( + self.delete_endpoint, default_timeout=None, client_info=client_info, + ), + self.deploy_model: gapic_v1.method.wrap_method( + self.deploy_model, default_timeout=None, client_info=client_info, + ), + self.undeploy_model: gapic_v1.method.wrap_method( + self.undeploy_model, default_timeout=None, client_info=client_info, + ), + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def create_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.CreateEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def get_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.GetEndpointRequest], + typing.Union[endpoint.Endpoint, typing.Awaitable[endpoint.Endpoint]], + ]: + raise NotImplementedError() + + @property + def list_endpoints( + self, + ) -> typing.Callable[ + [endpoint_service.ListEndpointsRequest], + typing.Union[ + endpoint_service.ListEndpointsResponse, + typing.Awaitable[endpoint_service.ListEndpointsResponse], + ], + ]: + raise NotImplementedError() + + @property + def update_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.UpdateEndpointRequest], + typing.Union[gca_endpoint.Endpoint, typing.Awaitable[gca_endpoint.Endpoint]], + ]: + raise NotImplementedError() + + @property + def delete_endpoint( + self, + ) -> typing.Callable[ + [endpoint_service.DeleteEndpointRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def deploy_model( + self, + ) -> typing.Callable[ + [endpoint_service.DeployModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def undeploy_model( + self, + ) -> typing.Callable[ + [endpoint_service.UndeployModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + +__all__ = ("EndpointServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py new file mode 100644 index 0000000000..f0b8b32de1 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc.py @@ -0,0 +1,456 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import warnings
+from typing import Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import grpc_helpers  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+
+from google.cloud.aiplatform_v1.types import endpoint
+from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint
+from google.cloud.aiplatform_v1.types import endpoint_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+from .base import EndpointServiceTransport, DEFAULT_CLIENT_INFO
+
+
+class EndpointServiceGrpcTransport(EndpointServiceTransport):
+    """gRPC backend transport for EndpointService.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _stubs: Dict[str, Callable]
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Sequence[str] = None,
+        channel: grpc.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+            channel (Optional[grpc.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+ ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None + + # Run the base constructor. 
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> grpc.Channel:
+        """Create and return a gRPC channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    @property
+    def grpc_channel(self) -> grpc.Channel:
+        """Return the channel designed to connect to this service.
+        """
+        return self._grpc_channel
+
+    @property
+    def operations_client(self) -> operations_v1.OperationsClient:
+        """Create the client designed to process long-running operations.
+
+        This property caches on the instance; repeated calls return the same
+        client.
+        """
+        # Sanity check: Only create a new client if we do not already have one.
+        if self._operations_client is None:
+            self._operations_client = operations_v1.OperationsClient(self.grpc_channel)
+
+        # Return the client from cache.
+        return self._operations_client
+
+    @property
+    def create_endpoint(
+        self,
+    ) -> Callable[[endpoint_service.CreateEndpointRequest], operations.Operation]:
+        r"""Return a callable for the create endpoint method over gRPC.
+
+        Creates an Endpoint.
+
+        Returns:
+            Callable[[~.CreateEndpointRequest],
+                    ~.Operation]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "create_endpoint" not in self._stubs:
+            self._stubs["create_endpoint"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.EndpointService/CreateEndpoint",
+                request_serializer=endpoint_service.CreateEndpointRequest.serialize,
+                response_deserializer=operations.Operation.FromString,
+            )
+        return self._stubs["create_endpoint"]
+
+    @property
+    def get_endpoint(
+        self,
+    ) -> Callable[[endpoint_service.GetEndpointRequest], endpoint.Endpoint]:
+        r"""Return a callable for the get endpoint method over gRPC.
+ + Gets an Endpoint. + + Returns: + Callable[[~.GetEndpointRequest], + ~.Endpoint]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_endpoint" not in self._stubs: + self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/GetEndpoint", + request_serializer=endpoint_service.GetEndpointRequest.serialize, + response_deserializer=endpoint.Endpoint.deserialize, + ) + return self._stubs["get_endpoint"] + + @property + def list_endpoints( + self, + ) -> Callable[ + [endpoint_service.ListEndpointsRequest], endpoint_service.ListEndpointsResponse + ]: + r"""Return a callable for the list endpoints method over gRPC. + + Lists Endpoints in a Location. + + Returns: + Callable[[~.ListEndpointsRequest], + ~.ListEndpointsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_endpoints" not in self._stubs: + self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/ListEndpoints", + request_serializer=endpoint_service.ListEndpointsRequest.serialize, + response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, + ) + return self._stubs["list_endpoints"] + + @property + def update_endpoint( + self, + ) -> Callable[[endpoint_service.UpdateEndpointRequest], gca_endpoint.Endpoint]: + r"""Return a callable for the update endpoint method over gRPC. + + Updates an Endpoint. + + Returns: + Callable[[~.UpdateEndpointRequest], + ~.Endpoint]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_endpoint" not in self._stubs: + self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/UpdateEndpoint", + request_serializer=endpoint_service.UpdateEndpointRequest.serialize, + response_deserializer=gca_endpoint.Endpoint.deserialize, + ) + return self._stubs["update_endpoint"] + + @property + def delete_endpoint( + self, + ) -> Callable[[endpoint_service.DeleteEndpointRequest], operations.Operation]: + r"""Return a callable for the delete endpoint method over gRPC. + + Deletes an Endpoint. + + Returns: + Callable[[~.DeleteEndpointRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
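+        # Note: the stub is created lazily on first property access and
+        # cached in self._stubs, so repeated lookups reuse the same
+        # channel-bound callable.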
+ if "delete_endpoint" not in self._stubs: + self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/DeleteEndpoint", + request_serializer=endpoint_service.DeleteEndpointRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_endpoint"] + + @property + def deploy_model( + self, + ) -> Callable[[endpoint_service.DeployModelRequest], operations.Operation]: + r"""Return a callable for the deploy model method over gRPC. + + Deploys a Model into this Endpoint, creating a + DeployedModel within it. + + Returns: + Callable[[~.DeployModelRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "deploy_model" not in self._stubs: + self._stubs["deploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/DeployModel", + request_serializer=endpoint_service.DeployModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["deploy_model"] + + @property + def undeploy_model( + self, + ) -> Callable[[endpoint_service.UndeployModelRequest], operations.Operation]: + r"""Return a callable for the undeploy model method over gRPC. + + Undeploys a Model from an Endpoint, removing a + DeployedModel from it, and freeing all resources it's + using. + + Returns: + Callable[[~.UndeployModelRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "undeploy_model" not in self._stubs: + self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/UndeployModel", + request_serializer=endpoint_service.UndeployModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["undeploy_model"] + + +__all__ = ("EndpointServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..ef97ba490f --- /dev/null +++ b/google/cloud/aiplatform_v1/services/endpoint_service/transports/grpc_asyncio.py @@ -0,0 +1,473 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1.types import endpoint
+from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint
+from google.cloud.aiplatform_v1.types import endpoint_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+from .base import EndpointServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import EndpointServiceGrpcTransport
+
+
+class EndpointServiceGrpcAsyncIOTransport(EndpointServiceTransport):
+    """gRPC AsyncIO backend transport for EndpointService.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers_async.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        channel: aio.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            channel (Optional[aio.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure mutual TLS channel. It is
+                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        self._ssl_channel_credentials = ssl_channel_credentials
+
+        if api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+        if client_cert_source:
+            warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+            self._ssl_channel_credentials = None
+        elif api_mtls_endpoint:
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+ self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + self._operations_client = None + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_endpoint( + self, + ) -> Callable[ + [endpoint_service.CreateEndpointRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the create endpoint method over gRPC. + + Creates an Endpoint. + + Returns: + Callable[[~.CreateEndpointRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_endpoint" not in self._stubs: + self._stubs["create_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/CreateEndpoint", + request_serializer=endpoint_service.CreateEndpointRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["create_endpoint"] + + @property + def get_endpoint( + self, + ) -> Callable[[endpoint_service.GetEndpointRequest], Awaitable[endpoint.Endpoint]]: + r"""Return a callable for the get endpoint method over gRPC. + + Gets an Endpoint. 
+ + Returns: + Callable[[~.GetEndpointRequest], + Awaitable[~.Endpoint]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_endpoint" not in self._stubs: + self._stubs["get_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/GetEndpoint", + request_serializer=endpoint_service.GetEndpointRequest.serialize, + response_deserializer=endpoint.Endpoint.deserialize, + ) + return self._stubs["get_endpoint"] + + @property + def list_endpoints( + self, + ) -> Callable[ + [endpoint_service.ListEndpointsRequest], + Awaitable[endpoint_service.ListEndpointsResponse], + ]: + r"""Return a callable for the list endpoints method over gRPC. + + Lists Endpoints in a Location. + + Returns: + Callable[[~.ListEndpointsRequest], + Awaitable[~.ListEndpointsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_endpoints" not in self._stubs: + self._stubs["list_endpoints"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/ListEndpoints", + request_serializer=endpoint_service.ListEndpointsRequest.serialize, + response_deserializer=endpoint_service.ListEndpointsResponse.deserialize, + ) + return self._stubs["list_endpoints"] + + @property + def update_endpoint( + self, + ) -> Callable[ + [endpoint_service.UpdateEndpointRequest], Awaitable[gca_endpoint.Endpoint] + ]: + r"""Return a callable for the update endpoint method over gRPC. + + Updates an Endpoint. + + Returns: + Callable[[~.UpdateEndpointRequest], + Awaitable[~.Endpoint]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_endpoint" not in self._stubs: + self._stubs["update_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/UpdateEndpoint", + request_serializer=endpoint_service.UpdateEndpointRequest.serialize, + response_deserializer=gca_endpoint.Endpoint.deserialize, + ) + return self._stubs["update_endpoint"] + + @property + def delete_endpoint( + self, + ) -> Callable[ + [endpoint_service.DeleteEndpointRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the delete endpoint method over gRPC. + + Deletes an Endpoint. + + Returns: + Callable[[~.DeleteEndpointRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
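+        # In the AsyncIO transport the stub returns an awaitable call
+        # object; awaiting it yields the long-running Operation message.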
+ if "delete_endpoint" not in self._stubs: + self._stubs["delete_endpoint"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/DeleteEndpoint", + request_serializer=endpoint_service.DeleteEndpointRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_endpoint"] + + @property + def deploy_model( + self, + ) -> Callable[ + [endpoint_service.DeployModelRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the deploy model method over gRPC. + + Deploys a Model into this Endpoint, creating a + DeployedModel within it. + + Returns: + Callable[[~.DeployModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "deploy_model" not in self._stubs: + self._stubs["deploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/DeployModel", + request_serializer=endpoint_service.DeployModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["deploy_model"] + + @property + def undeploy_model( + self, + ) -> Callable[ + [endpoint_service.UndeployModelRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the undeploy model method over gRPC. + + Undeploys a Model from an Endpoint, removing a + DeployedModel from it, and freeing all resources it's + using. + + Returns: + Callable[[~.UndeployModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "undeploy_model" not in self._stubs: + self._stubs["undeploy_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.EndpointService/UndeployModel", + request_serializer=endpoint_service.UndeployModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["undeploy_model"] + + +__all__ = ("EndpointServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/job_service/__init__.py b/google/cloud/aiplatform_v1/services/job_service/__init__.py new file mode 100644 index 0000000000..5f157047f5 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/job_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from .client import JobServiceClient +from .async_client import JobServiceAsyncClient + +__all__ = ( + "JobServiceClient", + "JobServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1/services/job_service/async_client.py b/google/cloud/aiplatform_v1/services/job_service/async_client.py new file mode 100644 index 0000000000..689cb131ea --- /dev/null +++ b/google/cloud/aiplatform_v1/services/job_service/async_client.py @@ -0,0 +1,1874 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.job_service import pagers +from google.cloud.aiplatform_v1.types import batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) +from google.cloud.aiplatform_v1.types import completion_stats +from google.cloud.aiplatform_v1.types import custom_job +from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job +from google.cloud.aiplatform_v1.types import data_labeling_job +from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) +from google.cloud.aiplatform_v1.types import job_service +from google.cloud.aiplatform_v1.types import job_state +from google.cloud.aiplatform_v1.types import machine_resources +from google.cloud.aiplatform_v1.types import manual_batch_tuning_parameters +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import study +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore +from google.type import money_pb2 as money # type: ignore + +from .transports.base import JobServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import JobServiceGrpcAsyncIOTransport +from .client import JobServiceClient + + +class JobServiceAsyncClient: + """A service for creating and managing AI Platform's jobs.""" + + _client: JobServiceClient + + 
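+    # A minimal usage sketch (hypothetical resource name; assumes
+    # Application Default Credentials are configured in the environment):
+    #
+    #     client = JobServiceAsyncClient()
+    #     job = await client.get_custom_job(
+    #         name="projects/my-project/locations/us-central1/customJobs/123"
+    #     )
+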
DEFAULT_ENDPOINT = JobServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = JobServiceClient.DEFAULT_MTLS_ENDPOINT + + batch_prediction_job_path = staticmethod(JobServiceClient.batch_prediction_job_path) + parse_batch_prediction_job_path = staticmethod( + JobServiceClient.parse_batch_prediction_job_path + ) + custom_job_path = staticmethod(JobServiceClient.custom_job_path) + parse_custom_job_path = staticmethod(JobServiceClient.parse_custom_job_path) + data_labeling_job_path = staticmethod(JobServiceClient.data_labeling_job_path) + parse_data_labeling_job_path = staticmethod( + JobServiceClient.parse_data_labeling_job_path + ) + dataset_path = staticmethod(JobServiceClient.dataset_path) + parse_dataset_path = staticmethod(JobServiceClient.parse_dataset_path) + hyperparameter_tuning_job_path = staticmethod( + JobServiceClient.hyperparameter_tuning_job_path + ) + parse_hyperparameter_tuning_job_path = staticmethod( + JobServiceClient.parse_hyperparameter_tuning_job_path + ) + model_path = staticmethod(JobServiceClient.model_path) + parse_model_path = staticmethod(JobServiceClient.parse_model_path) + trial_path = staticmethod(JobServiceClient.trial_path) + parse_trial_path = staticmethod(JobServiceClient.parse_trial_path) + + common_billing_account_path = staticmethod( + JobServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + JobServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(JobServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(JobServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(JobServiceClient.common_organization_path) + parse_common_organization_path = staticmethod( + JobServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(JobServiceClient.common_project_path) + parse_common_project_path = staticmethod(JobServiceClient.parse_common_project_path) + + common_location_path = staticmethod(JobServiceClient.common_location_path) + parse_common_location_path = staticmethod( + JobServiceClient.parse_common_location_path + ) + + from_service_account_info = JobServiceClient.from_service_account_info + from_service_account_file = JobServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> JobServiceTransport: + """Return the transport used by the client instance. + + Returns: + JobServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(JobServiceClient).get_transport_class, type(JobServiceClient) + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, JobServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the job service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.JobServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. 
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. The GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+        """
+
+        self._client = JobServiceClient(
+            credentials=credentials,
+            transport=transport,
+            client_options=client_options,
+            client_info=client_info,
+        )
+
+    async def create_custom_job(
+        self,
+        request: job_service.CreateCustomJobRequest = None,
+        *,
+        parent: str = None,
+        custom_job: gca_custom_job.CustomJob = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> gca_custom_job.CustomJob:
+        r"""Creates a CustomJob. A created CustomJob will be
+        attempted to run right away.
+
+        Args:
+            request (:class:`google.cloud.aiplatform_v1.types.CreateCustomJobRequest`):
+                The request object. Request message for
+                ``JobService.CreateCustomJob``.
+            parent (:class:`str`):
+                Required. The resource name of the Location to create
+                the CustomJob in. Format:
+                ``projects/{project}/locations/{location}``
+
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            custom_job (:class:`google.cloud.aiplatform_v1.types.CustomJob`):
+                Required. The CustomJob to create.
+                This corresponds to the ``custom_job`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.cloud.aiplatform_v1.types.CustomJob:
+                Represents a job that runs custom
+                workloads such as a Docker container or
+                a Python package. A CustomJob can have
+                multiple worker pools and each worker
+                pool can have its own machine and input
+                spec. A CustomJob will be cleaned up
+                once the job enters terminal state
+                (failed or succeeded).
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([parent, custom_job])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        request = job_service.CreateCustomJobRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
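+        # (The flattened ``parent`` and ``custom_job`` arguments are a
+        # convenience; they are simply copied onto the request message.)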
+ + if parent is not None: + request.parent = parent + if custom_job is not None: + request.custom_job = custom_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_custom_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_custom_job( + self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: + r"""Gets a CustomJob. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.GetCustomJobRequest`): + The request object. Request message for + ``JobService.GetCustomJob``. + name (:class:`str`): + Required. The name of the CustomJob resource. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.CustomJob: + Represents a job that runs custom + workloads such as a Docker container or + a Python package. A CustomJob can have + multiple worker pools and each worker + pool can have its own machine and input + spec. A CustomJob will be cleaned up + once the job enters terminal state + (failed or succeeded). + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.GetCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_custom_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. 
+ return response + + async def list_custom_jobs( + self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsAsyncPager: + r"""Lists CustomJobs in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListCustomJobsRequest`): + The request object. Request message for + ``JobService.ListCustomJobs``. + parent (:class:`str`): + Required. The resource name of the Location to list the + CustomJobs from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.job_service.pagers.ListCustomJobsAsyncPager: + Response message for + ``JobService.ListCustomJobs`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.ListCustomJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_custom_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListCustomJobsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_custom_job( + self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a CustomJob. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.DeleteCustomJobRequest`): + The request object. Request message for + ``JobService.DeleteCustomJob``. + name (:class:`str`): + Required. The name of the CustomJob resource to be + deleted. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+ + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.DeleteCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_custom_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def cancel_custom_job( + self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a CustomJob. Starts asynchronous cancellation on the + CustomJob. The server makes a best effort to cancel the job, but + success is not guaranteed. Clients can use + ``JobService.GetCustomJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the CustomJob is not deleted; instead it becomes a + job with a + ``CustomJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``CustomJob.state`` is + set to ``CANCELLED``. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.CancelCustomJobRequest`): + The request object. Request message for + ``JobService.CancelCustomJob``. + name (:class:`str`): + Required. The name of the CustomJob to cancel. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. 
+ + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.CancelCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_custom_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + await rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def create_data_labeling_job( + self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: + r"""Creates a DataLabelingJob. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.CreateDataLabelingJobRequest`): + The request object. Request message for + [DataLabelingJobService.CreateDataLabelingJob][]. + parent (:class:`str`): + Required. The parent of the DataLabelingJob. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + data_labeling_job (:class:`google.cloud.aiplatform_v1.types.DataLabelingJob`): + Required. The DataLabelingJob to + create. + + This corresponds to the ``data_labeling_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.DataLabelingJob: + DataLabelingJob is used to trigger a + human labeling job on unlabeled data + from the following Dataset: + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, data_labeling_job]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.CreateDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
+ + if parent is not None: + request.parent = parent + if data_labeling_job is not None: + request.data_labeling_job = data_labeling_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_data_labeling_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_data_labeling_job( + self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: + r"""Gets a DataLabelingJob. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.GetDataLabelingJobRequest`): + The request object. Request message for + [DataLabelingJobService.GetDataLabelingJob][]. + name (:class:`str`): + Required. The name of the DataLabelingJob. Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.DataLabelingJob: + DataLabelingJob is used to trigger a + human labeling job on unlabeled data + from the following Dataset: + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.GetDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_data_labeling_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_data_labeling_jobs( + self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsAsyncPager: + r"""Lists DataLabelingJobs in a Location. 
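+
+ A minimal usage sketch (the project and region below are
+ placeholders, and default application credentials are assumed):
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ client = aiplatform_v1.JobServiceAsyncClient()
+ pager = await client.list_data_labeling_jobs(
+ parent="projects/my-project/locations/us-central1"
+ )
+ async for data_labeling_job in pager:
+ print(data_labeling_job.name)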
+ + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListDataLabelingJobsRequest`): + The request object. Request message for + [DataLabelingJobService.ListDataLabelingJobs][]. + parent (:class:`str`): + Required. The parent of the DataLabelingJob. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.job_service.pagers.ListDataLabelingJobsAsyncPager: + Response message for + ``JobService.ListDataLabelingJobs``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.ListDataLabelingJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_data_labeling_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListDataLabelingJobsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_data_labeling_job( + self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a DataLabelingJob. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.DeleteDataLabelingJobRequest`): + The request object. Request message for + ``JobService.DeleteDataLabelingJob``. + name (:class:`str`): + Required. The name of the DataLabelingJob to be deleted. + Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty`, a generic empty message that you can re-use to avoid defining duplicated
+ empty messages in your APIs. A typical example is to
+ use it as the request or the response type of an API
+ method. For instance:
+
+ service Foo {
+ rpc Bar(google.protobuf.Empty) returns
+ (google.protobuf.Empty);
+
+ }
+
+ The JSON representation for ``Empty`` is an empty
+ JSON object, ``{}``.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = job_service.DeleteDataLabelingJobRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.delete_data_labeling_job,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ empty.Empty,
+ metadata_type=gca_operation.DeleteOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def cancel_data_labeling_job(
+ self,
+ request: job_service.CancelDataLabelingJobRequest = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Cancels a DataLabelingJob. Success of cancellation is
+ not guaranteed.
+
+ Args:
+ request (:class:`google.cloud.aiplatform_v1.types.CancelDataLabelingJobRequest`):
+ The request object. Request message for
+ [DataLabelingJobService.CancelDataLabelingJob][].
+ name (:class:`str`):
+ Required. The name of the DataLabelingJob. Format:
+
+ ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}``
+
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ ) + + request = job_service.CancelDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_data_labeling_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + await rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + async def create_hyperparameter_tuning_job( + self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + r"""Creates a HyperparameterTuningJob + + Args: + request (:class:`google.cloud.aiplatform_v1.types.CreateHyperparameterTuningJobRequest`): + The request object. Request message for + ``JobService.CreateHyperparameterTuningJob``. + parent (:class:`str`): + Required. The resource name of the Location to create + the HyperparameterTuningJob in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + hyperparameter_tuning_job (:class:`google.cloud.aiplatform_v1.types.HyperparameterTuningJob`): + Required. The HyperparameterTuningJob + to create. + + This corresponds to the ``hyperparameter_tuning_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.HyperparameterTuningJob: + Represents a HyperparameterTuningJob. + A HyperparameterTuningJob has a Study + specification and multiple CustomJobs + with identical CustomJob specification. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, hyperparameter_tuning_job]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.CreateHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if hyperparameter_tuning_job is not None: + request.hyperparameter_tuning_job = hyperparameter_tuning_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
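+ # (`gapic_v1.method_async.wrap_method` returns an awaitable wrapper
+ # around the transport coroutine; because `default_timeout` is None
+ # here, calls only retry or time out when the caller passes explicit
+ # `retry` / `timeout` arguments.)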
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_hyperparameter_tuning_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_hyperparameter_tuning_job( + self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + r"""Gets a HyperparameterTuningJob + + Args: + request (:class:`google.cloud.aiplatform_v1.types.GetHyperparameterTuningJobRequest`): + The request object. Request message for + ``JobService.GetHyperparameterTuningJob``. + name (:class:`str`): + Required. The name of the HyperparameterTuningJob + resource. Format: + + ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.HyperparameterTuningJob: + Represents a HyperparameterTuningJob. + A HyperparameterTuningJob has a Study + specification and multiple CustomJobs + with identical CustomJob specification. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.GetHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_hyperparameter_tuning_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_hyperparameter_tuning_jobs( + self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsAsyncPager: + r"""Lists HyperparameterTuningJobs in a Location. 
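+
+ A minimal usage sketch (placeholder names; page-level iteration
+ shown via the pager's ``pages`` property):
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ client = aiplatform_v1.JobServiceAsyncClient()
+ pager = await client.list_hyperparameter_tuning_jobs(
+ parent="projects/my-project/locations/us-central1"
+ )
+ async for page in pager.pages:
+ for tuning_job in page.hyperparameter_tuning_jobs:
+ print(tuning_job.display_name)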
+ + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListHyperparameterTuningJobsRequest`): + The request object. Request message for + ``JobService.ListHyperparameterTuningJobs``. + parent (:class:`str`): + Required. The resource name of the Location to list the + HyperparameterTuningJobs from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.job_service.pagers.ListHyperparameterTuningJobsAsyncPager: + Response message for + ``JobService.ListHyperparameterTuningJobs`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.ListHyperparameterTuningJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_hyperparameter_tuning_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListHyperparameterTuningJobsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_hyperparameter_tuning_job( + self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a HyperparameterTuningJob. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.DeleteHyperparameterTuningJobRequest`): + The request object. Request message for + ``JobService.DeleteHyperparameterTuningJob``. + name (:class:`str`): + Required. The name of the HyperparameterTuningJob + resource to be deleted. Format: + + ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty`, a generic empty message that you can re-use to avoid defining duplicated
+ empty messages in your APIs. A typical example is to
+ use it as the request or the response type of an API
+ method. For instance:
+
+ service Foo {
+ rpc Bar(google.protobuf.Empty) returns
+ (google.protobuf.Empty);
+
+ }
+
+ The JSON representation for ``Empty`` is an empty
+ JSON object, ``{}``.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = job_service.DeleteHyperparameterTuningJobRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.delete_hyperparameter_tuning_job,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ empty.Empty,
+ metadata_type=gca_operation.DeleteOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def cancel_hyperparameter_tuning_job(
+ self,
+ request: job_service.CancelHyperparameterTuningJobRequest = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Cancels a HyperparameterTuningJob. Starts asynchronous
+ cancellation on the HyperparameterTuningJob. The server makes a
+ best effort to cancel the job, but success is not guaranteed.
+ Clients can use
+ ``JobService.GetHyperparameterTuningJob``
+ or other methods to check whether the cancellation succeeded or
+ whether the job completed despite cancellation. On successful
+ cancellation, the HyperparameterTuningJob is not deleted;
+ instead it becomes a job with a
+ ``HyperparameterTuningJob.error``
+ value with a ``google.rpc.Status.code`` of
+ 1, corresponding to ``Code.CANCELLED``, and
+ ``HyperparameterTuningJob.state``
+ is set to ``CANCELLED``.
+
+ Args:
+ request (:class:`google.cloud.aiplatform_v1.types.CancelHyperparameterTuningJobRequest`):
+ The request object. Request message for
+ ``JobService.CancelHyperparameterTuningJob``.
+ name (:class:`str`):
+ Required. The name of the HyperparameterTuningJob to
+ cancel.
Format:
+
+ ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}``
+
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = job_service.CancelHyperparameterTuningJobRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.cancel_hyperparameter_tuning_job,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ await rpc(
+ request, retry=retry, timeout=timeout, metadata=metadata,
+ )
+
+ async def create_batch_prediction_job(
+ self,
+ request: job_service.CreateBatchPredictionJobRequest = None,
+ *,
+ parent: str = None,
+ batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> gca_batch_prediction_job.BatchPredictionJob:
+ r"""Creates a BatchPredictionJob. The service attempts to
+ start a newly created BatchPredictionJob right away.
+
+ Args:
+ request (:class:`google.cloud.aiplatform_v1.types.CreateBatchPredictionJobRequest`):
+ The request object. Request message for
+ ``JobService.CreateBatchPredictionJob``.
+ parent (:class:`str`):
+ Required. The resource name of the Location to create
+ the BatchPredictionJob in. Format:
+ ``projects/{project}/locations/{location}``
+
+ This corresponds to the ``parent`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+ batch_prediction_job (:class:`google.cloud.aiplatform_v1.types.BatchPredictionJob`):
+ Required. The BatchPredictionJob to
+ create.
+
+ This corresponds to the ``batch_prediction_job`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.types.BatchPredictionJob:
+ A job that uses a ``Model`` to produce predictions
+ on multiple [input
+ instances][google.cloud.aiplatform.v1.BatchPredictionJob.input_config].
+ If predictions for a significant portion of the
+ instances fail, the job may finish without attempting
+ predictions for all remaining instances.
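+
+ A minimal usage sketch (placeholder names; a real
+ BatchPredictionJob also needs ``input_config`` and
+ ``output_config``, omitted here for brevity):
+
+ .. code-block:: python
+
+ from google.cloud import aiplatform_v1
+
+ client = aiplatform_v1.JobServiceAsyncClient()
+ job = aiplatform_v1.BatchPredictionJob(
+ display_name="my-batch-job",
+ model="projects/my-project/locations/us-central1/models/123",
+ )
+ response = await client.create_batch_prediction_job(
+ parent="projects/my-project/locations/us-central1",
+ batch_prediction_job=job,
+ )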
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([parent, batch_prediction_job])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = job_service.CreateBatchPredictionJobRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+
+ if parent is not None:
+ request.parent = parent
+ if batch_prediction_job is not None:
+ request.batch_prediction_job = batch_prediction_job
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.create_batch_prediction_job,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Done; return the response.
+ return response
+
+ async def get_batch_prediction_job(
+ self,
+ request: job_service.GetBatchPredictionJobRequest = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> batch_prediction_job.BatchPredictionJob:
+ r"""Gets a BatchPredictionJob.
+
+ Args:
+ request (:class:`google.cloud.aiplatform_v1.types.GetBatchPredictionJobRequest`):
+ The request object. Request message for
+ ``JobService.GetBatchPredictionJob``.
+ name (:class:`str`):
+ Required. The name of the BatchPredictionJob resource.
+ Format:
+
+ ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}``
+
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.cloud.aiplatform_v1.types.BatchPredictionJob:
+ A job that uses a ``Model`` to produce predictions
+ on multiple [input
+ instances][google.cloud.aiplatform.v1.BatchPredictionJob.input_config].
+ If predictions for a significant portion of the
+ instances fail, the job may finish without attempting
+ predictions for all remaining instances.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = job_service.GetBatchPredictionJobRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
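+ # (`client_info` here only stamps the outgoing `x-goog-api-client`
+ # header with the library and gapic versions; it does not affect
+ # retry or timeout behavior.)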
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_batch_prediction_job, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_batch_prediction_jobs( + self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsAsyncPager: + r"""Lists BatchPredictionJobs in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListBatchPredictionJobsRequest`): + The request object. Request message for + ``JobService.ListBatchPredictionJobs``. + parent (:class:`str`): + Required. The resource name of the Location to list the + BatchPredictionJobs from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.job_service.pagers.ListBatchPredictionJobsAsyncPager: + Response message for + ``JobService.ListBatchPredictionJobs`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = job_service.ListBatchPredictionJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_batch_prediction_jobs, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListBatchPredictionJobsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. 
+ return response
+
+ async def delete_batch_prediction_job(
+ self,
+ request: job_service.DeleteBatchPredictionJobRequest = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> operation_async.AsyncOperation:
+ r"""Deletes a BatchPredictionJob. Can only be called on
+ jobs that have already finished.
+
+ Args:
+ request (:class:`google.cloud.aiplatform_v1.types.DeleteBatchPredictionJobRequest`):
+ The request object. Request message for
+ ``JobService.DeleteBatchPredictionJob``.
+ name (:class:`str`):
+ Required. The name of the BatchPredictionJob resource to
+ be deleted. Format:
+
+ ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}``
+
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+
+ Returns:
+ google.api_core.operation_async.AsyncOperation:
+ An object representing a long-running operation.
+
+ The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty`, a generic empty message that you can re-use to avoid defining duplicated
+ empty messages in your APIs. A typical example is to
+ use it as the request or the response type of an API
+ method. For instance:
+
+ service Foo {
+ rpc Bar(google.protobuf.Empty) returns
+ (google.protobuf.Empty);
+
+ }
+
+ The JSON representation for ``Empty`` is an empty
+ JSON object, ``{}``.
+
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = job_service.DeleteBatchPredictionJobRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.delete_batch_prediction_job,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+ # Wrap the response in an operation future.
+ response = operation_async.from_gapic(
+ response,
+ self._client._transport.operations_client,
+ empty.Empty,
+ metadata_type=gca_operation.DeleteOperationMetadata,
+ )
+
+ # Done; return the response.
+ return response
+
+ async def cancel_batch_prediction_job(
+ self,
+ request: job_service.CancelBatchPredictionJobRequest = None,
+ *,
+ name: str = None,
+ retry: retries.Retry = gapic_v1.method.DEFAULT,
+ timeout: float = None,
+ metadata: Sequence[Tuple[str, str]] = (),
+ ) -> None:
+ r"""Cancels a BatchPredictionJob.
+
+ Starts asynchronous cancellation on the BatchPredictionJob.
The
+ server makes the best effort to cancel the job, but success is
+ not guaranteed. Clients can use
+ ``JobService.GetBatchPredictionJob``
+ or other methods to check whether the cancellation succeeded or
+ whether the job completed despite cancellation. On a successful
+ cancellation, the BatchPredictionJob is not deleted; instead its
+ ``BatchPredictionJob.state``
+ is set to ``CANCELLED``. Any files already outputted by the job
+ are not deleted.
+
+ Args:
+ request (:class:`google.cloud.aiplatform_v1.types.CancelBatchPredictionJobRequest`):
+ The request object. Request message for
+ ``JobService.CancelBatchPredictionJob``.
+ name (:class:`str`):
+ Required. The name of the BatchPredictionJob to cancel.
+ Format:
+
+ ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}``
+
+ This corresponds to the ``name`` field
+ on the ``request`` instance; if ``request`` is provided, this
+ should not be set.
+
+ retry (google.api_core.retry.Retry): Designation of what errors, if any,
+ should be retried.
+ timeout (float): The timeout for this request.
+ metadata (Sequence[Tuple[str, str]]): Strings which should be
+ sent along with the request as metadata.
+ """
+ # Create or coerce a protobuf request object.
+ # Sanity check: If we got a request object, we should *not* have
+ # gotten any keyword arguments that map to the request.
+ has_flattened_params = any([name])
+ if request is not None and has_flattened_params:
+ raise ValueError(
+ "If the `request` argument is set, then none of "
+ "the individual field arguments should be set."
+ )
+
+ request = job_service.CancelBatchPredictionJobRequest(request)
+
+ # If we have keyword arguments corresponding to fields on the
+ # request, apply these.
+
+ if name is not None:
+ request.name = name
+
+ # Wrap the RPC method; this adds retry and timeout information,
+ # and friendly error handling.
+ rpc = gapic_v1.method_async.wrap_method(
+ self._client._transport.cancel_batch_prediction_job,
+ default_timeout=None,
+ client_info=DEFAULT_CLIENT_INFO,
+ )
+
+ # Certain fields should be provided within the metadata header;
+ # add these here.
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
+ )
+
+ # Send the request.
+ await rpc(
+ request, retry=retry, timeout=timeout, metadata=metadata,
+ )
+
+
+try:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+ gapic_version=pkg_resources.get_distribution(
+ "google-cloud-aiplatform",
+ ).version,
+ )
+except pkg_resources.DistributionNotFound:
+ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+__all__ = ("JobServiceAsyncClient",)
diff --git a/google/cloud/aiplatform_v1/services/job_service/client.py b/google/cloud/aiplatform_v1/services/job_service/client.py
new file mode 100644
index 0000000000..746ce91c4b
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/job_service/client.py
@@ -0,0 +1,2204 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.job_service import pagers +from google.cloud.aiplatform_v1.types import batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) +from google.cloud.aiplatform_v1.types import completion_stats +from google.cloud.aiplatform_v1.types import custom_job +from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job +from google.cloud.aiplatform_v1.types import data_labeling_job +from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) +from google.cloud.aiplatform_v1.types import job_service +from google.cloud.aiplatform_v1.types import job_state +from google.cloud.aiplatform_v1.types import machine_resources +from google.cloud.aiplatform_v1.types import manual_batch_tuning_parameters +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import study +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore +from google.type import money_pb2 as money # type: ignore + +from .transports.base import JobServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import JobServiceGrpcTransport +from .transports.grpc_asyncio import JobServiceGrpcAsyncIOTransport + + +class JobServiceClientMeta(type): + """Metaclass for the JobService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] + _transport_registry["grpc"] = JobServiceGrpcTransport + _transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[JobServiceTransport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. 
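+ # (The registry above maps "grpc" to JobServiceGrpcTransport and
+ # "grpc_asyncio" to JobServiceGrpcAsyncIOTransport; an unknown
+ # label raises KeyError.)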
+ if label:
+ return cls._transport_registry[label]
+
+ # No transport is requested; return the default (that is, the first one
+ # in the dictionary).
+ return next(iter(cls._transport_registry.values()))
+
+
+class JobServiceClient(metaclass=JobServiceClientMeta):
+ """A service for creating and managing AI Platform's jobs."""
+
+ @staticmethod
+ def _get_default_mtls_endpoint(api_endpoint):
+ """Convert api endpoint to mTLS endpoint.
+ Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+ "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+ Args:
+ api_endpoint (Optional[str]): the api endpoint to convert.
+ Returns:
+ str: converted mTLS api endpoint.
+ """
+ if not api_endpoint:
+ return api_endpoint
+
+ mtls_endpoint_re = re.compile(
+ r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+ )
+
+ m = mtls_endpoint_re.match(api_endpoint)
+ name, mtls, sandbox, googledomain = m.groups()
+ if mtls or not googledomain:
+ return api_endpoint
+
+ if sandbox:
+ return api_endpoint.replace(
+ "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+ )
+
+ return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+ DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
+ DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+ DEFAULT_ENDPOINT
+ )
+
+ @classmethod
+ def from_service_account_info(cls, info: dict, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials info.
+
+ Args:
+ info (dict): The service account private key info.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ JobServiceClient: The constructed client.
+ """
+ credentials = service_account.Credentials.from_service_account_info(info)
+ kwargs["credentials"] = credentials
+ return cls(*args, **kwargs)
+
+ @classmethod
+ def from_service_account_file(cls, filename: str, *args, **kwargs):
+ """Creates an instance of this client using the provided credentials
+ file.
+
+ Args:
+ filename (str): The path to the service account private key json
+ file.
+ args: Additional arguments to pass to the constructor.
+ kwargs: Additional arguments to pass to the constructor.
+
+ Returns:
+ JobServiceClient: The constructed client.
+ """
+ credentials = service_account.Credentials.from_service_account_file(filename)
+ kwargs["credentials"] = credentials
+ return cls(*args, **kwargs)
+
+ from_service_account_json = from_service_account_file
+
+ @property
+ def transport(self) -> JobServiceTransport:
+ """Return the transport used by the client instance.
+
+ Returns:
+ JobServiceTransport: The transport used by the client instance.
+ """ + return self._transport + + @staticmethod + def batch_prediction_job_path( + project: str, location: str, batch_prediction_job: str, + ) -> str: + """Return a fully-qualified batch_prediction_job string.""" + return "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format( + project=project, + location=location, + batch_prediction_job=batch_prediction_job, + ) + + @staticmethod + def parse_batch_prediction_job_path(path: str) -> Dict[str, str]: + """Parse a batch_prediction_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/batchPredictionJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def custom_job_path(project: str, location: str, custom_job: str,) -> str: + """Return a fully-qualified custom_job string.""" + return "projects/{project}/locations/{location}/customJobs/{custom_job}".format( + project=project, location=location, custom_job=custom_job, + ) + + @staticmethod + def parse_custom_job_path(path: str) -> Dict[str, str]: + """Parse a custom_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/customJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def data_labeling_job_path( + project: str, location: str, data_labeling_job: str, + ) -> str: + """Return a fully-qualified data_labeling_job string.""" + return "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format( + project=project, location=location, data_labeling_job=data_labeling_job, + ) + + @staticmethod + def parse_data_labeling_job_path(path: str) -> Dict[str, str]: + """Parse a data_labeling_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/dataLabelingJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def dataset_path(project: str, location: str, dataset: str,) -> str: + """Return a fully-qualified dataset string.""" + return "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) + + @staticmethod + def parse_dataset_path(path: str) -> Dict[str, str]: + """Parse a dataset path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def hyperparameter_tuning_job_path( + project: str, location: str, hyperparameter_tuning_job: str, + ) -> str: + """Return a fully-qualified hyperparameter_tuning_job string.""" + return "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( + project=project, + location=location, + hyperparameter_tuning_job=hyperparameter_tuning_job, + ) + + @staticmethod + def parse_hyperparameter_tuning_job_path(path: str) -> Dict[str, str]: + """Parse a hyperparameter_tuning_job path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/hyperparameterTuningJobs/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def model_path(project: str, location: str, model: str,) -> str: + """Return a fully-qualified model string.""" + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str, str]: + """Parse a model path into its component segments.""" + m = re.match( + 
r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def trial_path(project: str, location: str, study: str, trial: str,) -> str: + """Return a fully-qualified trial string.""" + return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) + + @staticmethod + def parse_trial_path(path: str) -> Dict[str, str]: + """Parse a trial path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, JobServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the job service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, JobServiceTransport]): The + transport to use. 
If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, JobServiceTransport): + # transport is a JobServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." 
+ ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def create_custom_job( + self, + request: job_service.CreateCustomJobRequest = None, + *, + parent: str = None, + custom_job: gca_custom_job.CustomJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_custom_job.CustomJob: + r"""Creates a CustomJob. A created CustomJob will + immediately attempt to run. + + Args: + request (google.cloud.aiplatform_v1.types.CreateCustomJobRequest): + The request object. Request message for + ``JobService.CreateCustomJob``. + parent (str): + Required. The resource name of the Location to create + the CustomJob in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + custom_job (google.cloud.aiplatform_v1.types.CustomJob): + Required. The CustomJob to create. + This corresponds to the ``custom_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.CustomJob: + Represents a job that runs custom + workloads such as a Docker container or + a Python package. A CustomJob can have + multiple worker pools and each worker + pool can have its own machine and input + spec. A CustomJob will be cleaned up + once the job enters terminal state + (failed or succeeded). + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, custom_job]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateCustomJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateCustomJobRequest): + request = job_service.CreateCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if custom_job is not None: + request.custom_job = custom_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_custom_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response.
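+ # Illustrative call shape for the flattened-argument form (a sketch, + # not generated code; the display_name value is a placeholder): + # + # job = client.create_custom_job( + # parent=JobServiceClient.common_location_path("my-project", "us-central1"), + # custom_job=gca_custom_job.CustomJob(display_name="example-custom-job"), + # )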
+ return response + + def get_custom_job( + self, + request: job_service.GetCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> custom_job.CustomJob: + r"""Gets a CustomJob. + + Args: + request (google.cloud.aiplatform_v1.types.GetCustomJobRequest): + The request object. Request message for + ``JobService.GetCustomJob``. + name (str): + Required. The name of the CustomJob resource. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.CustomJob: + Represents a job that runs custom + workloads such as a Docker container or + a Python package. A CustomJob can have + multiple worker pools and each worker + pool can have its own machine and input + spec. A CustomJob will be cleaned up + once the job enters terminal state + (failed or succeeded). + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetCustomJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.GetCustomJobRequest): + request = job_service.GetCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_custom_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_custom_jobs( + self, + request: job_service.ListCustomJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListCustomJobsPager: + r"""Lists CustomJobs in a Location. + + Args: + request (google.cloud.aiplatform_v1.types.ListCustomJobsRequest): + The request object. Request message for + ``JobService.ListCustomJobs``. + parent (str): + Required. The resource name of the Location to list the + CustomJobs from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. 
+ timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.job_service.pagers.ListCustomJobsPager: + Response message for + ``JobService.ListCustomJobs`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListCustomJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListCustomJobsRequest): + request = job_service.ListCustomJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_custom_jobs] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListCustomJobsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_custom_job( + self, + request: job_service.DeleteCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Deletes a CustomJob. + + Args: + request (google.cloud.aiplatform_v1.types.DeleteCustomJobRequest): + The request object. Request message for + ``JobService.DeleteCustomJob``. + name (str): + Required. The name of the CustomJob resource to be + deleted. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. 
For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteCustomJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteCustomJobRequest): + request = job_service.DeleteCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_custom_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def cancel_custom_job( + self, + request: job_service.CancelCustomJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a CustomJob. Starts asynchronous cancellation on the + CustomJob. The server makes a best effort to cancel the job, but + success is not guaranteed. Clients can use + ``JobService.GetCustomJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the CustomJob is not deleted; instead it becomes a + job with a + ``CustomJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``CustomJob.state`` is + set to ``CANCELLED``. + + Args: + request (google.cloud.aiplatform_v1.types.CancelCustomJobRequest): + The request object. Request message for + ``JobService.CancelCustomJob``. + name (str): + Required. The name of the CustomJob to cancel. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
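+ # Either calling convention below is valid, but they are mutually + # exclusive (sketch; the resource IDs are placeholders): + # + # client.cancel_custom_job(name="projects/p/locations/l/customJobs/123") + # client.cancel_custom_job( + # request=job_service.CancelCustomJobRequest( + # name="projects/p/locations/l/customJobs/123" + # ) + # )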
+ has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CancelCustomJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CancelCustomJobRequest): + request = job_service.CancelCustomJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.cancel_custom_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def create_data_labeling_job( + self, + request: job_service.CreateDataLabelingJobRequest = None, + *, + parent: str = None, + data_labeling_job: gca_data_labeling_job.DataLabelingJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_data_labeling_job.DataLabelingJob: + r"""Creates a DataLabelingJob. + + Args: + request (google.cloud.aiplatform_v1.types.CreateDataLabelingJobRequest): + The request object. Request message for + [DataLabelingJobService.CreateDataLabelingJob][]. + parent (str): + Required. The parent of the DataLabelingJob. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + data_labeling_job (google.cloud.aiplatform_v1.types.DataLabelingJob): + Required. The DataLabelingJob to + create. + + This corresponds to the ``data_labeling_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.DataLabelingJob: + DataLabelingJob is used to trigger a + human labeling job on unlabeled data + from the following Dataset: + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, data_labeling_job]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateDataLabelingJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateDataLabelingJobRequest): + request = job_service.CreateDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
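+ # For reference, ``parent`` is a Location resource name; it can be built + # with the helper defined on this class (sketch): + # + # parent = JobServiceClient.common_location_path("my-project", "us-central1")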
+ + if parent is not None: + request.parent = parent + if data_labeling_job is not None: + request.data_labeling_job = data_labeling_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_data_labeling_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def get_data_labeling_job( + self, + request: job_service.GetDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> data_labeling_job.DataLabelingJob: + r"""Gets a DataLabelingJob. + + Args: + request (google.cloud.aiplatform_v1.types.GetDataLabelingJobRequest): + The request object. Request message for + [DataLabelingJobService.GetDataLabelingJob][]. + name (str): + Required. The name of the DataLabelingJob. Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.DataLabelingJob: + DataLabelingJob is used to trigger a + human labeling job on unlabeled data + from the following Dataset: + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetDataLabelingJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.GetDataLabelingJobRequest): + request = job_service.GetDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_data_labeling_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. 
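+ # A typical follow-up on the returned message (illustrative; ``state`` is + # the JobState enum field on the DataLabelingJob proto): + # + # job = client.get_data_labeling_job(name=name) + # print(job.state)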
+ return response + + def list_data_labeling_jobs( + self, + request: job_service.ListDataLabelingJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListDataLabelingJobsPager: + r"""Lists DataLabelingJobs in a Location. + + Args: + request (google.cloud.aiplatform_v1.types.ListDataLabelingJobsRequest): + The request object. Request message for + [DataLabelingJobService.ListDataLabelingJobs][]. + parent (str): + Required. The parent of the DataLabelingJob. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.job_service.pagers.ListDataLabelingJobsPager: + Response message for + ``JobService.ListDataLabelingJobs``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListDataLabelingJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListDataLabelingJobsRequest): + request = job_service.ListDataLabelingJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_data_labeling_jobs] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListDataLabelingJobsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_data_labeling_job( + self, + request: job_service.DeleteDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Deletes a DataLabelingJob. + + Args: + request (google.cloud.aiplatform_v1.types.DeleteDataLabelingJobRequest): + The request object. Request message for + ``JobService.DeleteDataLabelingJob``. + name (str): + Required. The name of the DataLabelingJob to be deleted. 
+ Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteDataLabelingJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteDataLabelingJobRequest): + request = job_service.DeleteDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_data_labeling_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def cancel_data_labeling_job( + self, + request: job_service.CancelDataLabelingJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a DataLabelingJob. Success of cancellation is + not guaranteed. + + Args: + request (google.cloud.aiplatform_v1.types.CancelDataLabelingJobRequest): + The request object. Request message for + [DataLabelingJobService.CancelDataLabelingJob][]. + name (str): + Required. The name of the DataLabelingJob. Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. 
+ timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CancelDataLabelingJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CancelDataLabelingJobRequest): + request = job_service.CancelDataLabelingJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.cancel_data_labeling_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def create_hyperparameter_tuning_job( + self, + request: job_service.CreateHyperparameterTuningJobRequest = None, + *, + parent: str = None, + hyperparameter_tuning_job: gca_hyperparameter_tuning_job.HyperparameterTuningJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_hyperparameter_tuning_job.HyperparameterTuningJob: + r"""Creates a HyperparameterTuningJob + + Args: + request (google.cloud.aiplatform_v1.types.CreateHyperparameterTuningJobRequest): + The request object. Request message for + ``JobService.CreateHyperparameterTuningJob``. + parent (str): + Required. The resource name of the Location to create + the HyperparameterTuningJob in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + hyperparameter_tuning_job (google.cloud.aiplatform_v1.types.HyperparameterTuningJob): + Required. The HyperparameterTuningJob + to create. + + This corresponds to the ``hyperparameter_tuning_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.HyperparameterTuningJob: + Represents a HyperparameterTuningJob. + A HyperparameterTuningJob has a Study + specification and multiple CustomJobs + with identical CustomJob specification. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
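+ # Minimal payload sketch (field names are assumptions drawn from the + # HyperparameterTuningJob message; values are placeholders): + # + # hp_job = gca_hyperparameter_tuning_job.HyperparameterTuningJob( + # display_name="example-tuning-job", + # max_trial_count=10, + # parallel_trial_count=2, + # ) + # client.create_hyperparameter_tuning_job( + # parent=parent, hyperparameter_tuning_job=hp_job, + # )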
+ has_flattened_params = any([parent, hyperparameter_tuning_job]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateHyperparameterTuningJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateHyperparameterTuningJobRequest): + request = job_service.CreateHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if hyperparameter_tuning_job is not None: + request.hyperparameter_tuning_job = hyperparameter_tuning_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.create_hyperparameter_tuning_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def get_hyperparameter_tuning_job( + self, + request: job_service.GetHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> hyperparameter_tuning_job.HyperparameterTuningJob: + r"""Gets a HyperparameterTuningJob + + Args: + request (google.cloud.aiplatform_v1.types.GetHyperparameterTuningJobRequest): + The request object. Request message for + ``JobService.GetHyperparameterTuningJob``. + name (str): + Required. The name of the HyperparameterTuningJob + resource. Format: + + ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.HyperparameterTuningJob: + Represents a HyperparameterTuningJob. + A HyperparameterTuningJob has a Study + specification and multiple CustomJobs + with identical CustomJob specification. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetHyperparameterTuningJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
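+ # The expected ``name`` format, per the docstring above (segment values + # are placeholders): + # + # name = ( + # "projects/my-project/locations/us-central1" + # "/hyperparameterTuningJobs/456" + # )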
+ if not isinstance(request, job_service.GetHyperparameterTuningJobRequest): + request = job_service.GetHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.get_hyperparameter_tuning_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_hyperparameter_tuning_jobs( + self, + request: job_service.ListHyperparameterTuningJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListHyperparameterTuningJobsPager: + r"""Lists HyperparameterTuningJobs in a Location. + + Args: + request (google.cloud.aiplatform_v1.types.ListHyperparameterTuningJobsRequest): + The request object. Request message for + ``JobService.ListHyperparameterTuningJobs``. + parent (str): + Required. The resource name of the Location to list the + HyperparameterTuningJobs from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.job_service.pagers.ListHyperparameterTuningJobsPager: + Response message for + ``JobService.ListHyperparameterTuningJobs`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListHyperparameterTuningJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListHyperparameterTuningJobsRequest): + request = job_service.ListHyperparameterTuningJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.list_hyperparameter_tuning_jobs + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. 
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListHyperparameterTuningJobsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_hyperparameter_tuning_job( + self, + request: job_service.DeleteHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Deletes a HyperparameterTuningJob. + + Args: + request (google.cloud.aiplatform_v1.types.DeleteHyperparameterTuningJobRequest): + The request object. Request message for + ``JobService.DeleteHyperparameterTuningJob``. + name (str): + Required. The name of the HyperparameterTuningJob + resource to be deleted. Format: + + ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteHyperparameterTuningJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteHyperparameterTuningJobRequest): + request = job_service.DeleteHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.delete_hyperparameter_tuning_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. 
+ response = ga_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def cancel_hyperparameter_tuning_job( + self, + request: job_service.CancelHyperparameterTuningJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a HyperparameterTuningJob. Starts asynchronous + cancellation on the HyperparameterTuningJob. The server makes a + best effort to cancel the job, but success is not guaranteed. + Clients can use + ``JobService.GetHyperparameterTuningJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the HyperparameterTuningJob is not deleted; + instead it becomes a job with a + ``HyperparameterTuningJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``HyperparameterTuningJob.state`` + is set to ``CANCELLED``. + + Args: + request (google.cloud.aiplatform_v1.types.CancelHyperparameterTuningJobRequest): + The request object. Request message for + ``JobService.CancelHyperparameterTuningJob``. + name (str): + Required. The name of the HyperparameterTuningJob to + cancel. Format: + + ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CancelHyperparameterTuningJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CancelHyperparameterTuningJobRequest): + request = job_service.CancelHyperparameterTuningJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.cancel_hyperparameter_tuning_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
+ rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + def create_batch_prediction_job( + self, + request: job_service.CreateBatchPredictionJobRequest = None, + *, + parent: str = None, + batch_prediction_job: gca_batch_prediction_job.BatchPredictionJob = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_batch_prediction_job.BatchPredictionJob: + r"""Creates a BatchPredictionJob. Once created, a + BatchPredictionJob will immediately attempt to start. + + Args: + request (google.cloud.aiplatform_v1.types.CreateBatchPredictionJobRequest): + The request object. Request message for + ``JobService.CreateBatchPredictionJob``. + parent (str): + Required. The resource name of the Location to create + the BatchPredictionJob in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + batch_prediction_job (google.cloud.aiplatform_v1.types.BatchPredictionJob): + Required. The BatchPredictionJob to + create. + + This corresponds to the ``batch_prediction_job`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.BatchPredictionJob: + A job that uses a ``Model`` to produce predictions + on multiple [input + instances][google.cloud.aiplatform.v1.BatchPredictionJob.input_config]. + If predictions for a significant portion of the + instances fail, the job may finish without attempting + predictions for all remaining instances. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, batch_prediction_job]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CreateBatchPredictionJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CreateBatchPredictionJobRequest): + request = job_service.CreateBatchPredictionJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if batch_prediction_job is not None: + request.batch_prediction_job = batch_prediction_job + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.create_batch_prediction_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response.
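+ # The created job is returned directly rather than as a long-running + # operation; its progress can be observed by re-fetching it (sketch): + # + # job = client.create_batch_prediction_job( + # parent=parent, batch_prediction_job=batch_prediction_job, + # ) + # job = client.get_batch_prediction_job(name=job.name)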
+ return response + + def get_batch_prediction_job( + self, + request: job_service.GetBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> batch_prediction_job.BatchPredictionJob: + r"""Gets a BatchPredictionJob + + Args: + request (google.cloud.aiplatform_v1.types.GetBatchPredictionJobRequest): + The request object. Request message for + ``JobService.GetBatchPredictionJob``. + name (str): + Required. The name of the BatchPredictionJob resource. + Format: + + ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.BatchPredictionJob: + A job that uses a ``Model`` to produce predictions + on multiple [input + instances][google.cloud.aiplatform.v1.BatchPredictionJob.input_config]. + If predictions for a significant portion of the + instances fail, the job may finish without attempting + predictions for all remaining instances. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.GetBatchPredictionJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.GetBatchPredictionJobRequest): + request = job_service.GetBatchPredictionJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_batch_prediction_job] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_batch_prediction_jobs( + self, + request: job_service.ListBatchPredictionJobsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListBatchPredictionJobsPager: + r"""Lists BatchPredictionJobs in a Location. + + Args: + request (google.cloud.aiplatform_v1.types.ListBatchPredictionJobsRequest): + The request object. Request message for + ``JobService.ListBatchPredictionJobs``. + parent (str): + Required. The resource name of the Location to list the + BatchPredictionJobs from.
Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.job_service.pagers.ListBatchPredictionJobsPager: + Response message for + ``JobService.ListBatchPredictionJobs`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.ListBatchPredictionJobsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.ListBatchPredictionJobsRequest): + request = job_service.ListBatchPredictionJobsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.list_batch_prediction_jobs + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListBatchPredictionJobsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_batch_prediction_job( + self, + request: job_service.DeleteBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Deletes a BatchPredictionJob. Can only be called on + jobs that already finished. + + Args: + request (google.cloud.aiplatform_v1.types.DeleteBatchPredictionJobRequest): + The request object. Request message for + ``JobService.DeleteBatchPredictionJob``. + name (str): + Required. The name of the BatchPredictionJob resource to + be deleted. Format: + + ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.DeleteBatchPredictionJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.DeleteBatchPredictionJobRequest): + request = job_service.DeleteBatchPredictionJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.delete_batch_prediction_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def cancel_batch_prediction_job( + self, + request: job_service.CancelBatchPredictionJobRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a BatchPredictionJob. + + Starts asynchronous cancellation on the BatchPredictionJob. The + server makes the best effort to cancel the job, but success is + not guaranteed. Clients can use + ``JobService.GetBatchPredictionJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On a successful + cancellation, the BatchPredictionJob is not deleted; instead its + ``BatchPredictionJob.state`` + is set to ``CANCELLED``. Any files already output by the job + are not deleted. + + Args: + request (google.cloud.aiplatform_v1.types.CancelBatchPredictionJobRequest): + The request object. Request message for + ``JobService.CancelBatchPredictionJob``. + name (str): + Required. The name of the BatchPredictionJob to cancel. + Format: + + ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set.
+ + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a job_service.CancelBatchPredictionJobRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, job_service.CancelBatchPredictionJobRequest): + request = job_service.CancelBatchPredictionJobRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.cancel_batch_prediction_job + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("JobServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/job_service/pagers.py b/google/cloud/aiplatform_v1/services/job_service/pagers.py new file mode 100644 index 0000000000..b5a0f4b929 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/job_service/pagers.py @@ -0,0 +1,542 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple + +from google.cloud.aiplatform_v1.types import batch_prediction_job +from google.cloud.aiplatform_v1.types import custom_job +from google.cloud.aiplatform_v1.types import data_labeling_job +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import job_service + + +class ListCustomJobsPager: + """A pager for iterating through ``list_custom_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListCustomJobsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``custom_jobs`` field. 
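+
+    A minimal usage sketch (the ``parent`` value is hypothetical; in practice
+    this pager is returned by ``JobServiceClient.list_custom_jobs`` rather
+    than constructed directly)::
+
+        from google.cloud import aiplatform_v1
+
+        client = aiplatform_v1.JobServiceClient()
+        for custom_job in client.list_custom_jobs(
+            parent="projects/my-project/locations/us-central1"
+        ):
+            print(custom_job.name)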
+ + If there are more pages, the ``__iter__`` method will make additional + ``ListCustomJobs`` requests and continue to iterate + through the ``custom_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListCustomJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., job_service.ListCustomJobsResponse], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListCustomJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListCustomJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListCustomJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[job_service.ListCustomJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[custom_job.CustomJob]: + for page in self.pages: + yield from page.custom_jobs + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListCustomJobsAsyncPager: + """A pager for iterating through ``list_custom_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListCustomJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``custom_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListCustomJobs`` requests and continue to iterate + through the ``custom_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListCustomJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListCustomJobsResponse]], + request: job_service.ListCustomJobsRequest, + response: job_service.ListCustomJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListCustomJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListCustomJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = job_service.ListCustomJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[job_service.ListCustomJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[custom_job.CustomJob]: + async def async_generator(): + async for page in self.pages: + for response in page.custom_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListDataLabelingJobsPager: + """A pager for iterating through ``list_data_labeling_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListDataLabelingJobsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``data_labeling_jobs`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListDataLabelingJobs`` requests and continue to iterate + through the ``data_labeling_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListDataLabelingJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., job_service.ListDataLabelingJobsResponse], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListDataLabelingJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListDataLabelingJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListDataLabelingJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[job_service.ListDataLabelingJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[data_labeling_job.DataLabelingJob]: + for page in self.pages: + yield from page.data_labeling_jobs + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListDataLabelingJobsAsyncPager: + """A pager for iterating through ``list_data_labeling_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListDataLabelingJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``data_labeling_jobs`` field. 
+ + If there are more pages, the ``__aiter__`` method will make additional + ``ListDataLabelingJobs`` requests and continue to iterate + through the ``data_labeling_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListDataLabelingJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListDataLabelingJobsResponse]], + request: job_service.ListDataLabelingJobsRequest, + response: job_service.ListDataLabelingJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListDataLabelingJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListDataLabelingJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListDataLabelingJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[job_service.ListDataLabelingJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[data_labeling_job.DataLabelingJob]: + async def async_generator(): + async for page in self.pages: + for response in page.data_labeling_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListHyperparameterTuningJobsPager: + """A pager for iterating through ``list_hyperparameter_tuning_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListHyperparameterTuningJobsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``hyperparameter_tuning_jobs`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListHyperparameterTuningJobs`` requests and continue to iterate + through the ``hyperparameter_tuning_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListHyperparameterTuningJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., job_service.ListHyperparameterTuningJobsResponse], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListHyperparameterTuningJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListHyperparameterTuningJobsResponse): + The initial response object. 
+ metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListHyperparameterTuningJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[job_service.ListHyperparameterTuningJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[hyperparameter_tuning_job.HyperparameterTuningJob]: + for page in self.pages: + yield from page.hyperparameter_tuning_jobs + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListHyperparameterTuningJobsAsyncPager: + """A pager for iterating through ``list_hyperparameter_tuning_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListHyperparameterTuningJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``hyperparameter_tuning_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListHyperparameterTuningJobs`` requests and continue to iterate + through the ``hyperparameter_tuning_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListHyperparameterTuningJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[job_service.ListHyperparameterTuningJobsResponse] + ], + request: job_service.ListHyperparameterTuningJobsRequest, + response: job_service.ListHyperparameterTuningJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListHyperparameterTuningJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListHyperparameterTuningJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = job_service.ListHyperparameterTuningJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[job_service.ListHyperparameterTuningJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__( + self, + ) -> AsyncIterable[hyperparameter_tuning_job.HyperparameterTuningJob]: + async def async_generator(): + async for page in self.pages: + for response in page.hyperparameter_tuning_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListBatchPredictionJobsPager: + """A pager for iterating through ``list_batch_prediction_jobs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListBatchPredictionJobsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``batch_prediction_jobs`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListBatchPredictionJobs`` requests and continue to iterate + through the ``batch_prediction_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListBatchPredictionJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., job_service.ListBatchPredictionJobsResponse], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListBatchPredictionJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListBatchPredictionJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListBatchPredictionJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[job_service.ListBatchPredictionJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[batch_prediction_job.BatchPredictionJob]: + for page in self.pages: + yield from page.batch_prediction_jobs + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListBatchPredictionJobsAsyncPager: + """A pager for iterating through ``list_batch_prediction_jobs`` requests. 
+ + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListBatchPredictionJobsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``batch_prediction_jobs`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListBatchPredictionJobs`` requests and continue to iterate + through the ``batch_prediction_jobs`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListBatchPredictionJobsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[job_service.ListBatchPredictionJobsResponse]], + request: job_service.ListBatchPredictionJobsRequest, + response: job_service.ListBatchPredictionJobsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListBatchPredictionJobsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListBatchPredictionJobsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = job_service.ListBatchPredictionJobsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[job_service.ListBatchPredictionJobsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[batch_prediction_job.BatchPredictionJob]: + async def async_generator(): + async for page in self.pages: + for response in page.batch_prediction_jobs: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/job_service/transports/__init__.py new file mode 100644 index 0000000000..349bfbcdea --- /dev/null +++ b/google/cloud/aiplatform_v1/services/job_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import JobServiceTransport +from .grpc import JobServiceGrpcTransport +from .grpc_asyncio import JobServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. 
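+# The registry below maps a transport name to its class; callers can then
+# resolve a transport by name, e.g. (illustrative):
+#
+#     transport_cls = _transport_registry["grpc"]  # -> JobServiceGrpcTransport
+#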
+_transport_registry = OrderedDict() # type: Dict[str, Type[JobServiceTransport]] +_transport_registry["grpc"] = JobServiceGrpcTransport +_transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport + +__all__ = ( + "JobServiceTransport", + "JobServiceGrpcTransport", + "JobServiceGrpcAsyncIOTransport", +) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1/services/job_service/transports/base.py new file mode 100644 index 0000000000..42ab8e1688 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/job_service/transports/base.py @@ -0,0 +1,434 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import typing +import pkg_resources + +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.auth import credentials # type: ignore + +from google.cloud.aiplatform_v1.types import batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) +from google.cloud.aiplatform_v1.types import custom_job +from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job +from google.cloud.aiplatform_v1.types import data_labeling_job +from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) +from google.cloud.aiplatform_v1.types import job_service +from google.longrunning import operations_pb2 as operations # type: ignore +from google.protobuf import empty_pb2 as empty # type: ignore + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +class JobServiceTransport(abc.ABC): + """Abstract transport class for JobService.""" + + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. 
These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + """ + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = auth.load_credentials_from_file( + credentials_file, scopes=scopes, quota_project_id=quota_project_id + ) + + elif credentials is None: + credentials, _ = auth.default( + scopes=scopes, quota_project_id=quota_project_id + ) + + # Save the credentials. + self._credentials = credentials + + # Lifted into its own function so it can be stubbed out during tests. + self._prep_wrapped_messages(client_info) + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_custom_job: gapic_v1.method.wrap_method( + self.create_custom_job, default_timeout=None, client_info=client_info, + ), + self.get_custom_job: gapic_v1.method.wrap_method( + self.get_custom_job, default_timeout=None, client_info=client_info, + ), + self.list_custom_jobs: gapic_v1.method.wrap_method( + self.list_custom_jobs, default_timeout=None, client_info=client_info, + ), + self.delete_custom_job: gapic_v1.method.wrap_method( + self.delete_custom_job, default_timeout=None, client_info=client_info, + ), + self.cancel_custom_job: gapic_v1.method.wrap_method( + self.cancel_custom_job, default_timeout=None, client_info=client_info, + ), + self.create_data_labeling_job: gapic_v1.method.wrap_method( + self.create_data_labeling_job, + default_timeout=None, + client_info=client_info, + ), + self.get_data_labeling_job: gapic_v1.method.wrap_method( + self.get_data_labeling_job, + default_timeout=None, + client_info=client_info, + ), + self.list_data_labeling_jobs: gapic_v1.method.wrap_method( + self.list_data_labeling_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_data_labeling_job: gapic_v1.method.wrap_method( + self.delete_data_labeling_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_data_labeling_job: gapic_v1.method.wrap_method( + self.cancel_data_labeling_job, + default_timeout=None, + client_info=client_info, + ), + self.create_hyperparameter_tuning_job: gapic_v1.method.wrap_method( + self.create_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.get_hyperparameter_tuning_job: gapic_v1.method.wrap_method( + self.get_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.list_hyperparameter_tuning_jobs: gapic_v1.method.wrap_method( +
self.list_hyperparameter_tuning_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_hyperparameter_tuning_job: gapic_v1.method.wrap_method( + self.delete_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_hyperparameter_tuning_job: gapic_v1.method.wrap_method( + self.cancel_hyperparameter_tuning_job, + default_timeout=None, + client_info=client_info, + ), + self.create_batch_prediction_job: gapic_v1.method.wrap_method( + self.create_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.get_batch_prediction_job: gapic_v1.method.wrap_method( + self.get_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.list_batch_prediction_jobs: gapic_v1.method.wrap_method( + self.list_batch_prediction_jobs, + default_timeout=None, + client_info=client_info, + ), + self.delete_batch_prediction_job: gapic_v1.method.wrap_method( + self.delete_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + self.cancel_batch_prediction_job: gapic_v1.method.wrap_method( + self.cancel_batch_prediction_job, + default_timeout=None, + client_info=client_info, + ), + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def create_custom_job( + self, + ) -> typing.Callable[ + [job_service.CreateCustomJobRequest], + typing.Union[ + gca_custom_job.CustomJob, typing.Awaitable[gca_custom_job.CustomJob] + ], + ]: + raise NotImplementedError() + + @property + def get_custom_job( + self, + ) -> typing.Callable[ + [job_service.GetCustomJobRequest], + typing.Union[custom_job.CustomJob, typing.Awaitable[custom_job.CustomJob]], + ]: + raise NotImplementedError() + + @property + def list_custom_jobs( + self, + ) -> typing.Callable[ + [job_service.ListCustomJobsRequest], + typing.Union[ + job_service.ListCustomJobsResponse, + typing.Awaitable[job_service.ListCustomJobsResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_custom_job( + self, + ) -> typing.Callable[ + [job_service.DeleteCustomJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def cancel_custom_job( + self, + ) -> typing.Callable[ + [job_service.CancelCustomJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: + raise NotImplementedError() + + @property + def create_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.CreateDataLabelingJobRequest], + typing.Union[ + gca_data_labeling_job.DataLabelingJob, + typing.Awaitable[gca_data_labeling_job.DataLabelingJob], + ], + ]: + raise NotImplementedError() + + @property + def get_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.GetDataLabelingJobRequest], + typing.Union[ + data_labeling_job.DataLabelingJob, + typing.Awaitable[data_labeling_job.DataLabelingJob], + ], + ]: + raise NotImplementedError() + + @property + def list_data_labeling_jobs( + self, + ) -> typing.Callable[ + [job_service.ListDataLabelingJobsRequest], + typing.Union[ + job_service.ListDataLabelingJobsResponse, + typing.Awaitable[job_service.ListDataLabelingJobsResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.DeleteDataLabelingJobRequest], + typing.Union[operations.Operation, 
typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def cancel_data_labeling_job( + self, + ) -> typing.Callable[ + [job_service.CancelDataLabelingJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: + raise NotImplementedError() + + @property + def create_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + typing.Union[ + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], + ], + ]: + raise NotImplementedError() + + @property + def get_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.GetHyperparameterTuningJobRequest], + typing.Union[ + hyperparameter_tuning_job.HyperparameterTuningJob, + typing.Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], + ], + ]: + raise NotImplementedError() + + @property + def list_hyperparameter_tuning_jobs( + self, + ) -> typing.Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + typing.Union[ + job_service.ListHyperparameterTuningJobsResponse, + typing.Awaitable[job_service.ListHyperparameterTuningJobsResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def cancel_hyperparameter_tuning_job( + self, + ) -> typing.Callable[ + [job_service.CancelHyperparameterTuningJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: + raise NotImplementedError() + + @property + def create_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.CreateBatchPredictionJobRequest], + typing.Union[ + gca_batch_prediction_job.BatchPredictionJob, + typing.Awaitable[gca_batch_prediction_job.BatchPredictionJob], + ], + ]: + raise NotImplementedError() + + @property + def get_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.GetBatchPredictionJobRequest], + typing.Union[ + batch_prediction_job.BatchPredictionJob, + typing.Awaitable[batch_prediction_job.BatchPredictionJob], + ], + ]: + raise NotImplementedError() + + @property + def list_batch_prediction_jobs( + self, + ) -> typing.Callable[ + [job_service.ListBatchPredictionJobsRequest], + typing.Union[ + job_service.ListBatchPredictionJobsResponse, + typing.Awaitable[job_service.ListBatchPredictionJobsResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.DeleteBatchPredictionJobRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def cancel_batch_prediction_job( + self, + ) -> typing.Callable[ + [job_service.CancelBatchPredictionJobRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: + raise NotImplementedError() + + +__all__ = ("JobServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py new file mode 100644 index 0000000000..139aaf3345 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc.py @@ -0,0 +1,886 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1.types import batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) +from google.cloud.aiplatform_v1.types import custom_job +from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job +from google.cloud.aiplatform_v1.types import data_labeling_job +from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) +from google.cloud.aiplatform_v1.types import job_service +from google.longrunning import operations_pb2 as operations # type: ignore +from google.protobuf import empty_pb2 as empty # type: ignore + +from .base import JobServiceTransport, DEFAULT_CLIENT_INFO + + +class JobServiceGrpcTransport(JobServiceTransport): + """gRPC backend transport for JobService. + + A service for creating and managing AI Platform's jobs. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. 
+ scopes (Optional[Sequence[str]]): A list of scopes. This argument is + ignored if ``channel`` is provided. + channel (Optional[grpc.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored.
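+            # (Note: the -1 option values passed below mean "unlimited" for
+            # the gRPC max send/receive message sizes, lifting the default
+            # message size limits.)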
+ self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + @classmethod + def create_channel( + cls, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): An optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client.
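+
+        A brief usage sketch (the operation name shown is hypothetical)::
+
+            op = transport.operations_client.get_operation(
+                "projects/my-project/locations/us-central1/operations/123"
+            )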
+ """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + + # Return the client from cache. + return self._operations_client + + @property + def create_custom_job( + self, + ) -> Callable[[job_service.CreateCustomJobRequest], gca_custom_job.CustomJob]: + r"""Return a callable for the create custom job method over gRPC. + + Creates a CustomJob. A created CustomJob right away + will be attempted to be run. + + Returns: + Callable[[~.CreateCustomJobRequest], + ~.CustomJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_custom_job" not in self._stubs: + self._stubs["create_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateCustomJob", + request_serializer=job_service.CreateCustomJobRequest.serialize, + response_deserializer=gca_custom_job.CustomJob.deserialize, + ) + return self._stubs["create_custom_job"] + + @property + def get_custom_job( + self, + ) -> Callable[[job_service.GetCustomJobRequest], custom_job.CustomJob]: + r"""Return a callable for the get custom job method over gRPC. + + Gets a CustomJob. + + Returns: + Callable[[~.GetCustomJobRequest], + ~.CustomJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_custom_job" not in self._stubs: + self._stubs["get_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetCustomJob", + request_serializer=job_service.GetCustomJobRequest.serialize, + response_deserializer=custom_job.CustomJob.deserialize, + ) + return self._stubs["get_custom_job"] + + @property + def list_custom_jobs( + self, + ) -> Callable[ + [job_service.ListCustomJobsRequest], job_service.ListCustomJobsResponse + ]: + r"""Return a callable for the list custom jobs method over gRPC. + + Lists CustomJobs in a Location. + + Returns: + Callable[[~.ListCustomJobsRequest], + ~.ListCustomJobsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_custom_jobs" not in self._stubs: + self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListCustomJobs", + request_serializer=job_service.ListCustomJobsRequest.serialize, + response_deserializer=job_service.ListCustomJobsResponse.deserialize, + ) + return self._stubs["list_custom_jobs"] + + @property + def delete_custom_job( + self, + ) -> Callable[[job_service.DeleteCustomJobRequest], operations.Operation]: + r"""Return a callable for the delete custom job method over gRPC. + + Deletes a CustomJob. + + Returns: + Callable[[~.DeleteCustomJobRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_custom_job" not in self._stubs: + self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteCustomJob", + request_serializer=job_service.DeleteCustomJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_custom_job"] + + @property + def cancel_custom_job( + self, + ) -> Callable[[job_service.CancelCustomJobRequest], empty.Empty]: + r"""Return a callable for the cancel custom job method over gRPC. + + Cancels a CustomJob. Starts asynchronous cancellation on the + CustomJob. The server makes a best effort to cancel the job, but + success is not guaranteed. Clients can use + ``JobService.GetCustomJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the CustomJob is not deleted; instead it becomes a + job with a + ``CustomJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``CustomJob.state`` is + set to ``CANCELLED``. + + Returns: + Callable[[~.CancelCustomJobRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "cancel_custom_job" not in self._stubs: + self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelCustomJob", + request_serializer=job_service.CancelCustomJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_custom_job"] + + @property + def create_data_labeling_job( + self, + ) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + gca_data_labeling_job.DataLabelingJob, + ]: + r"""Return a callable for the create data labeling job method over gRPC. + + Creates a DataLabelingJob. + + Returns: + Callable[[~.CreateDataLabelingJobRequest], + ~.DataLabelingJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_data_labeling_job" not in self._stubs: + self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateDataLabelingJob", + request_serializer=job_service.CreateDataLabelingJobRequest.serialize, + response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, + ) + return self._stubs["create_data_labeling_job"] + + @property + def get_data_labeling_job( + self, + ) -> Callable[ + [job_service.GetDataLabelingJobRequest], data_labeling_job.DataLabelingJob + ]: + r"""Return a callable for the get data labeling job method over gRPC. + + Gets a DataLabelingJob. + + Returns: + Callable[[~.GetDataLabelingJobRequest], + ~.DataLabelingJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "get_data_labeling_job" not in self._stubs: + self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetDataLabelingJob", + request_serializer=job_service.GetDataLabelingJobRequest.serialize, + response_deserializer=data_labeling_job.DataLabelingJob.deserialize, + ) + return self._stubs["get_data_labeling_job"] + + @property + def list_data_labeling_jobs( + self, + ) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + job_service.ListDataLabelingJobsResponse, + ]: + r"""Return a callable for the list data labeling jobs method over gRPC. + + Lists DataLabelingJobs in a Location. + + Returns: + Callable[[~.ListDataLabelingJobsRequest], + ~.ListDataLabelingJobsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_data_labeling_jobs" not in self._stubs: + self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListDataLabelingJobs", + request_serializer=job_service.ListDataLabelingJobsRequest.serialize, + response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, + ) + return self._stubs["list_data_labeling_jobs"] + + @property + def delete_data_labeling_job( + self, + ) -> Callable[[job_service.DeleteDataLabelingJobRequest], operations.Operation]: + r"""Return a callable for the delete data labeling job method over gRPC. + + Deletes a DataLabelingJob. + + Returns: + Callable[[~.DeleteDataLabelingJobRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_data_labeling_job" not in self._stubs: + self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteDataLabelingJob", + request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_data_labeling_job"] + + @property + def cancel_data_labeling_job( + self, + ) -> Callable[[job_service.CancelDataLabelingJobRequest], empty.Empty]: + r"""Return a callable for the cancel data labeling job method over gRPC. + + Cancels a DataLabelingJob. Success of cancellation is + not guaranteed. + + Returns: + Callable[[~.CancelDataLabelingJobRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "cancel_data_labeling_job" not in self._stubs: + self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelDataLabelingJob", + request_serializer=job_service.CancelDataLabelingJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_data_labeling_job"] + + @property + def create_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + gca_hyperparameter_tuning_job.HyperparameterTuningJob, + ]: + r"""Return a callable for the create hyperparameter tuning + job method over gRPC. + + Creates a HyperparameterTuningJob + + Returns: + Callable[[~.CreateHyperparameterTuningJobRequest], + ~.HyperparameterTuningJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "create_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateHyperparameterTuningJob", + request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, + response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, + ) + return self._stubs["create_hyperparameter_tuning_job"] + + @property + def get_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + hyperparameter_tuning_job.HyperparameterTuningJob, + ]: + r"""Return a callable for the get hyperparameter tuning job method over gRPC. + + Gets a HyperparameterTuningJob + + Returns: + Callable[[~.GetHyperparameterTuningJobRequest], + ~.HyperparameterTuningJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "get_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetHyperparameterTuningJob", + request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, + response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, + ) + return self._stubs["get_hyperparameter_tuning_job"] + + @property + def list_hyperparameter_tuning_jobs( + self, + ) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + job_service.ListHyperparameterTuningJobsResponse, + ]: + r"""Return a callable for the list hyperparameter tuning + jobs method over gRPC. + + Lists HyperparameterTuningJobs in a Location. + + Returns: + Callable[[~.ListHyperparameterTuningJobsRequest], + ~.ListHyperparameterTuningJobsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
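Each property returns a plain callable, so the transport can also be exercised directly with a request proto; the client layer above it merely adds retry, timeout, and routing metadata. A hedged sketch (the parent path is a placeholder):

    from google.cloud.aiplatform_v1.services.job_service.transports.grpc import (
        JobServiceGrpcTransport,
    )
    from google.cloud.aiplatform_v1.types import job_service

    transport = JobServiceGrpcTransport(host="aiplatform.googleapis.com")
    request = job_service.ListHyperparameterTuningJobsRequest(
        parent="projects/my-project/locations/us-central1",  # placeholder
    )
    # Accessing the property yields the cached stub; calling it runs the RPC.
    response = transport.list_hyperparameter_tuning_jobs(request)
    for job in response.hyperparameter_tuning_jobs:
        print(job.name)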
+        if "list_hyperparameter_tuning_jobs" not in self._stubs:
+            self._stubs[
+                "list_hyperparameter_tuning_jobs"
+            ] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/ListHyperparameterTuningJobs",
+                request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize,
+                response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize,
+            )
+        return self._stubs["list_hyperparameter_tuning_jobs"]
+
+    @property
+    def delete_hyperparameter_tuning_job(
+        self,
+    ) -> Callable[
+        [job_service.DeleteHyperparameterTuningJobRequest], operations.Operation
+    ]:
+        r"""Return a callable for the delete hyperparameter tuning
+        job method over gRPC.
+
+        Deletes a HyperparameterTuningJob.
+
+        Returns:
+            Callable[[~.DeleteHyperparameterTuningJobRequest],
+                    ~.Operation]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "delete_hyperparameter_tuning_job" not in self._stubs:
+            self._stubs[
+                "delete_hyperparameter_tuning_job"
+            ] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/DeleteHyperparameterTuningJob",
+                request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize,
+                response_deserializer=operations.Operation.FromString,
+            )
+        return self._stubs["delete_hyperparameter_tuning_job"]
+
+    @property
+    def cancel_hyperparameter_tuning_job(
+        self,
+    ) -> Callable[[job_service.CancelHyperparameterTuningJobRequest], empty.Empty]:
+        r"""Return a callable for the cancel hyperparameter tuning
+        job method over gRPC.
+
+        Cancels a HyperparameterTuningJob. Starts asynchronous
+        cancellation on the HyperparameterTuningJob. The server makes a
+        best effort to cancel the job, but success is not guaranteed.
+        Clients can use
+        ``JobService.GetHyperparameterTuningJob``
+        or other methods to check whether the cancellation succeeded or
+        whether the job completed despite cancellation. On successful
+        cancellation, the HyperparameterTuningJob is not deleted;
+        instead it becomes a job with a
+        ``HyperparameterTuningJob.error``
+        value with a ``google.rpc.Status.code`` of
+        1, corresponding to ``Code.CANCELLED``, and
+        ``HyperparameterTuningJob.state``
+        is set to ``CANCELLED``.
+
+        Returns:
+            Callable[[~.CancelHyperparameterTuningJobRequest],
+                    ~.Empty]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "cancel_hyperparameter_tuning_job" not in self._stubs:
+            self._stubs[
+                "cancel_hyperparameter_tuning_job"
+            ] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/CancelHyperparameterTuningJob",
+                request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize,
+                response_deserializer=empty.Empty.FromString,
+            )
+        return self._stubs["cancel_hyperparameter_tuning_job"]
+
+    @property
+    def create_batch_prediction_job(
+        self,
+    ) -> Callable[
+        [job_service.CreateBatchPredictionJobRequest],
+        gca_batch_prediction_job.BatchPredictionJob,
+    ]:
+        r"""Return a callable for the create batch prediction job method over gRPC.
+
+        Creates a BatchPredictionJob. The server immediately
+        attempts to start a newly created BatchPredictionJob.
+ + Returns: + Callable[[~.CreateBatchPredictionJobRequest], + ~.BatchPredictionJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_batch_prediction_job" not in self._stubs: + self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateBatchPredictionJob", + request_serializer=job_service.CreateBatchPredictionJobRequest.serialize, + response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize, + ) + return self._stubs["create_batch_prediction_job"] + + @property + def get_batch_prediction_job( + self, + ) -> Callable[ + [job_service.GetBatchPredictionJobRequest], + batch_prediction_job.BatchPredictionJob, + ]: + r"""Return a callable for the get batch prediction job method over gRPC. + + Gets a BatchPredictionJob + + Returns: + Callable[[~.GetBatchPredictionJobRequest], + ~.BatchPredictionJob]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_batch_prediction_job" not in self._stubs: + self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetBatchPredictionJob", + request_serializer=job_service.GetBatchPredictionJobRequest.serialize, + response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize, + ) + return self._stubs["get_batch_prediction_job"] + + @property + def list_batch_prediction_jobs( + self, + ) -> Callable[ + [job_service.ListBatchPredictionJobsRequest], + job_service.ListBatchPredictionJobsResponse, + ]: + r"""Return a callable for the list batch prediction jobs method over gRPC. + + Lists BatchPredictionJobs in a Location. + + Returns: + Callable[[~.ListBatchPredictionJobsRequest], + ~.ListBatchPredictionJobsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_batch_prediction_jobs" not in self._stubs: + self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListBatchPredictionJobs", + request_serializer=job_service.ListBatchPredictionJobsRequest.serialize, + response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize, + ) + return self._stubs["list_batch_prediction_jobs"] + + @property + def delete_batch_prediction_job( + self, + ) -> Callable[[job_service.DeleteBatchPredictionJobRequest], operations.Operation]: + r"""Return a callable for the delete batch prediction job method over gRPC. + + Deletes a BatchPredictionJob. Can only be called on + jobs that already finished. + + Returns: + Callable[[~.DeleteBatchPredictionJobRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
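At the transport level a delete returns the raw ``google.longrunning.Operation`` proto; the public client wraps it in an ``operation.Operation`` future, which is the more convenient path. A sketch (the resource name is a placeholder):

    from google.cloud import aiplatform_v1

    client = aiplatform_v1.JobServiceClient()
    lro = client.delete_batch_prediction_job(
        name="projects/my-project/locations/us-central1/batchPredictionJobs/123"
    )
    lro.result()  # blocks until the server completes the deletion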
+        if "delete_batch_prediction_job" not in self._stubs:
+            self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/DeleteBatchPredictionJob",
+                request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize,
+                response_deserializer=operations.Operation.FromString,
+            )
+        return self._stubs["delete_batch_prediction_job"]
+
+    @property
+    def cancel_batch_prediction_job(
+        self,
+    ) -> Callable[[job_service.CancelBatchPredictionJobRequest], empty.Empty]:
+        r"""Return a callable for the cancel batch prediction job method over gRPC.
+
+        Cancels a BatchPredictionJob.
+
+        Starts asynchronous cancellation on the BatchPredictionJob. The
+        server makes the best effort to cancel the job, but success is
+        not guaranteed. Clients can use
+        ``JobService.GetBatchPredictionJob``
+        or other methods to check whether the cancellation succeeded or
+        whether the job completed despite cancellation. On a successful
+        cancellation, the BatchPredictionJob is not deleted; instead its
+        ``BatchPredictionJob.state``
+        is set to ``CANCELLED``. Any files already output by the job
+        are not deleted.
+
+        Returns:
+            Callable[[~.CancelBatchPredictionJobRequest],
+                    ~.Empty]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "cancel_batch_prediction_job" not in self._stubs:
+            self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/CancelBatchPredictionJob",
+                request_serializer=job_service.CancelBatchPredictionJobRequest.serialize,
+                response_deserializer=empty.Empty.FromString,
+            )
+        return self._stubs["cancel_batch_prediction_job"]
+
+
+__all__ = ("JobServiceGrpcTransport",)
diff --git a/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py
new file mode 100644
index 0000000000..f056094c9d
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/job_service/transports/grpc_asyncio.py
@@ -0,0 +1,907 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1.types import batch_prediction_job
+from google.cloud.aiplatform_v1.types import (
+    batch_prediction_job as gca_batch_prediction_job,
+)
+from google.cloud.aiplatform_v1.types import custom_job
+from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job
+from google.cloud.aiplatform_v1.types import data_labeling_job
+from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job
+from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job
+from google.cloud.aiplatform_v1.types import (
+    hyperparameter_tuning_job as gca_hyperparameter_tuning_job,
+)
+from google.cloud.aiplatform_v1.types import job_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+from google.protobuf import empty_pb2 as empty  # type: ignore
+
+from .base import JobServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import JobServiceGrpcTransport
+
+
+class JobServiceGrpcAsyncIOTransport(JobServiceTransport):
+    """gRPC AsyncIO backend transport for JobService.
+
+    A service for creating and managing AI Platform's jobs.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers_async.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        channel: aio.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id=None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            channel (Optional[aio.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure mutual TLS channel. It is
+                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+ """ + self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + self._operations_client = None + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. 
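The constructor above takes either a ready-made channel or the raw pieces to build one; both paths end in the same state. A sketch of each (the credentials file path is a placeholder):

    from google.cloud.aiplatform_v1.services.job_service.transports.grpc_asyncio import (
        JobServiceGrpcAsyncIOTransport,
    )

    # Let the transport assemble its own channel from credentials.
    transport = JobServiceGrpcAsyncIOTransport(
        host="aiplatform.googleapis.com",
        credentials_file="service-account.json",  # placeholder
    )

    # Or hand over a pre-built channel; credential arguments are then ignored.
    channel = JobServiceGrpcAsyncIOTransport.create_channel("aiplatform.googleapis.com")
    transport = JobServiceGrpcAsyncIOTransport(channel=channel)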
+        return self._operations_client
+
+    @property
+    def create_custom_job(
+        self,
+    ) -> Callable[
+        [job_service.CreateCustomJobRequest], Awaitable[gca_custom_job.CustomJob]
+    ]:
+        r"""Return a callable for the create custom job method over gRPC.
+
+        Creates a CustomJob. The server immediately attempts
+        to run a newly created CustomJob.
+
+        Returns:
+            Callable[[~.CreateCustomJobRequest],
+                    Awaitable[~.CustomJob]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "create_custom_job" not in self._stubs:
+            self._stubs["create_custom_job"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/CreateCustomJob",
+                request_serializer=job_service.CreateCustomJobRequest.serialize,
+                response_deserializer=gca_custom_job.CustomJob.deserialize,
+            )
+        return self._stubs["create_custom_job"]
+
+    @property
+    def get_custom_job(
+        self,
+    ) -> Callable[[job_service.GetCustomJobRequest], Awaitable[custom_job.CustomJob]]:
+        r"""Return a callable for the get custom job method over gRPC.
+
+        Gets a CustomJob.
+
+        Returns:
+            Callable[[~.GetCustomJobRequest],
+                    Awaitable[~.CustomJob]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "get_custom_job" not in self._stubs:
+            self._stubs["get_custom_job"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/GetCustomJob",
+                request_serializer=job_service.GetCustomJobRequest.serialize,
+                response_deserializer=custom_job.CustomJob.deserialize,
+            )
+        return self._stubs["get_custom_job"]
+
+    @property
+    def list_custom_jobs(
+        self,
+    ) -> Callable[
+        [job_service.ListCustomJobsRequest],
+        Awaitable[job_service.ListCustomJobsResponse],
+    ]:
+        r"""Return a callable for the list custom jobs method over gRPC.
+
+        Lists CustomJobs in a Location.
+
+        Returns:
+            Callable[[~.ListCustomJobsRequest],
+                    Awaitable[~.ListCustomJobsResponse]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "list_custom_jobs" not in self._stubs:
+            self._stubs["list_custom_jobs"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/ListCustomJobs",
+                request_serializer=job_service.ListCustomJobsRequest.serialize,
+                response_deserializer=job_service.ListCustomJobsResponse.deserialize,
+            )
+        return self._stubs["list_custom_jobs"]
+
+    @property
+    def delete_custom_job(
+        self,
+    ) -> Callable[
+        [job_service.DeleteCustomJobRequest], Awaitable[operations.Operation]
+    ]:
+        r"""Return a callable for the delete custom job method over gRPC.
+
+        Deletes a CustomJob.
+
+        Returns:
+            Callable[[~.DeleteCustomJobRequest],
+                    Awaitable[~.Operation]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
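Because every callable here returns an awaitable, the async transport composes naturally with ``asyncio``; most users reach it through the async client rather than directly. A sketch (the job name is a placeholder):

    import asyncio

    from google.cloud.aiplatform_v1.services.job_service import JobServiceAsyncClient

    async def main() -> None:
        client = JobServiceAsyncClient()
        job = await client.get_custom_job(
            name="projects/my-project/locations/us-central1/customJobs/123"
        )
        print(job.display_name)

    asyncio.run(main())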
+ if "delete_custom_job" not in self._stubs: + self._stubs["delete_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteCustomJob", + request_serializer=job_service.DeleteCustomJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_custom_job"] + + @property + def cancel_custom_job( + self, + ) -> Callable[[job_service.CancelCustomJobRequest], Awaitable[empty.Empty]]: + r"""Return a callable for the cancel custom job method over gRPC. + + Cancels a CustomJob. Starts asynchronous cancellation on the + CustomJob. The server makes a best effort to cancel the job, but + success is not guaranteed. Clients can use + ``JobService.GetCustomJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the CustomJob is not deleted; instead it becomes a + job with a + ``CustomJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``CustomJob.state`` is + set to ``CANCELLED``. + + Returns: + Callable[[~.CancelCustomJobRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "cancel_custom_job" not in self._stubs: + self._stubs["cancel_custom_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelCustomJob", + request_serializer=job_service.CancelCustomJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_custom_job"] + + @property + def create_data_labeling_job( + self, + ) -> Callable[ + [job_service.CreateDataLabelingJobRequest], + Awaitable[gca_data_labeling_job.DataLabelingJob], + ]: + r"""Return a callable for the create data labeling job method over gRPC. + + Creates a DataLabelingJob. + + Returns: + Callable[[~.CreateDataLabelingJobRequest], + Awaitable[~.DataLabelingJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_data_labeling_job" not in self._stubs: + self._stubs["create_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateDataLabelingJob", + request_serializer=job_service.CreateDataLabelingJobRequest.serialize, + response_deserializer=gca_data_labeling_job.DataLabelingJob.deserialize, + ) + return self._stubs["create_data_labeling_job"] + + @property + def get_data_labeling_job( + self, + ) -> Callable[ + [job_service.GetDataLabelingJobRequest], + Awaitable[data_labeling_job.DataLabelingJob], + ]: + r"""Return a callable for the get data labeling job method over gRPC. + + Gets a DataLabelingJob. + + Returns: + Callable[[~.GetDataLabelingJobRequest], + Awaitable[~.DataLabelingJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "get_data_labeling_job" not in self._stubs: + self._stubs["get_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetDataLabelingJob", + request_serializer=job_service.GetDataLabelingJobRequest.serialize, + response_deserializer=data_labeling_job.DataLabelingJob.deserialize, + ) + return self._stubs["get_data_labeling_job"] + + @property + def list_data_labeling_jobs( + self, + ) -> Callable[ + [job_service.ListDataLabelingJobsRequest], + Awaitable[job_service.ListDataLabelingJobsResponse], + ]: + r"""Return a callable for the list data labeling jobs method over gRPC. + + Lists DataLabelingJobs in a Location. + + Returns: + Callable[[~.ListDataLabelingJobsRequest], + Awaitable[~.ListDataLabelingJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_data_labeling_jobs" not in self._stubs: + self._stubs["list_data_labeling_jobs"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListDataLabelingJobs", + request_serializer=job_service.ListDataLabelingJobsRequest.serialize, + response_deserializer=job_service.ListDataLabelingJobsResponse.deserialize, + ) + return self._stubs["list_data_labeling_jobs"] + + @property + def delete_data_labeling_job( + self, + ) -> Callable[ + [job_service.DeleteDataLabelingJobRequest], Awaitable[operations.Operation] + ]: + r"""Return a callable for the delete data labeling job method over gRPC. + + Deletes a DataLabelingJob. + + Returns: + Callable[[~.DeleteDataLabelingJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_data_labeling_job" not in self._stubs: + self._stubs["delete_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteDataLabelingJob", + request_serializer=job_service.DeleteDataLabelingJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_data_labeling_job"] + + @property + def cancel_data_labeling_job( + self, + ) -> Callable[[job_service.CancelDataLabelingJobRequest], Awaitable[empty.Empty]]: + r"""Return a callable for the cancel data labeling job method over gRPC. + + Cancels a DataLabelingJob. Success of cancellation is + not guaranteed. + + Returns: + Callable[[~.CancelDataLabelingJobRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "cancel_data_labeling_job" not in self._stubs: + self._stubs["cancel_data_labeling_job"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelDataLabelingJob", + request_serializer=job_service.CancelDataLabelingJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_data_labeling_job"] + + @property + def create_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CreateHyperparameterTuningJobRequest], + Awaitable[gca_hyperparameter_tuning_job.HyperparameterTuningJob], + ]: + r"""Return a callable for the create hyperparameter tuning + job method over gRPC. + + Creates a HyperparameterTuningJob + + Returns: + Callable[[~.CreateHyperparameterTuningJobRequest], + Awaitable[~.HyperparameterTuningJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "create_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CreateHyperparameterTuningJob", + request_serializer=job_service.CreateHyperparameterTuningJobRequest.serialize, + response_deserializer=gca_hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, + ) + return self._stubs["create_hyperparameter_tuning_job"] + + @property + def get_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.GetHyperparameterTuningJobRequest], + Awaitable[hyperparameter_tuning_job.HyperparameterTuningJob], + ]: + r"""Return a callable for the get hyperparameter tuning job method over gRPC. + + Gets a HyperparameterTuningJob + + Returns: + Callable[[~.GetHyperparameterTuningJobRequest], + Awaitable[~.HyperparameterTuningJob]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "get_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/GetHyperparameterTuningJob", + request_serializer=job_service.GetHyperparameterTuningJobRequest.serialize, + response_deserializer=hyperparameter_tuning_job.HyperparameterTuningJob.deserialize, + ) + return self._stubs["get_hyperparameter_tuning_job"] + + @property + def list_hyperparameter_tuning_jobs( + self, + ) -> Callable[ + [job_service.ListHyperparameterTuningJobsRequest], + Awaitable[job_service.ListHyperparameterTuningJobsResponse], + ]: + r"""Return a callable for the list hyperparameter tuning + jobs method over gRPC. + + Lists HyperparameterTuningJobs in a Location. + + Returns: + Callable[[~.ListHyperparameterTuningJobsRequest], + Awaitable[~.ListHyperparameterTuningJobsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "list_hyperparameter_tuning_jobs" not in self._stubs: + self._stubs[ + "list_hyperparameter_tuning_jobs" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/ListHyperparameterTuningJobs", + request_serializer=job_service.ListHyperparameterTuningJobsRequest.serialize, + response_deserializer=job_service.ListHyperparameterTuningJobsResponse.deserialize, + ) + return self._stubs["list_hyperparameter_tuning_jobs"] + + @property + def delete_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.DeleteHyperparameterTuningJobRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the delete hyperparameter tuning + job method over gRPC. + + Deletes a HyperparameterTuningJob. + + Returns: + Callable[[~.DeleteHyperparameterTuningJobRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "delete_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/DeleteHyperparameterTuningJob", + request_serializer=job_service.DeleteHyperparameterTuningJobRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_hyperparameter_tuning_job"] + + @property + def cancel_hyperparameter_tuning_job( + self, + ) -> Callable[ + [job_service.CancelHyperparameterTuningJobRequest], Awaitable[empty.Empty] + ]: + r"""Return a callable for the cancel hyperparameter tuning + job method over gRPC. + + Cancels a HyperparameterTuningJob. Starts asynchronous + cancellation on the HyperparameterTuningJob. The server makes a + best effort to cancel the job, but success is not guaranteed. + Clients can use + ``JobService.GetHyperparameterTuningJob`` + or other methods to check whether the cancellation succeeded or + whether the job completed despite cancellation. On successful + cancellation, the HyperparameterTuningJob is not deleted; + instead it becomes a job with a + ``HyperparameterTuningJob.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``HyperparameterTuningJob.state`` + is set to ``CANCELLED``. + + Returns: + Callable[[~.CancelHyperparameterTuningJobRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "cancel_hyperparameter_tuning_job" not in self._stubs: + self._stubs[ + "cancel_hyperparameter_tuning_job" + ] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.JobService/CancelHyperparameterTuningJob", + request_serializer=job_service.CancelHyperparameterTuningJobRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_hyperparameter_tuning_job"] + + @property + def create_batch_prediction_job( + self, + ) -> Callable[ + [job_service.CreateBatchPredictionJobRequest], + Awaitable[gca_batch_prediction_job.BatchPredictionJob], + ]: + r"""Return a callable for the create batch prediction job method over gRPC. + + Creates a BatchPredictionJob. 
The server immediately
+        attempts to start a newly created BatchPredictionJob.
+
+        Returns:
+            Callable[[~.CreateBatchPredictionJobRequest],
+                    Awaitable[~.BatchPredictionJob]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "create_batch_prediction_job" not in self._stubs:
+            self._stubs["create_batch_prediction_job"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/CreateBatchPredictionJob",
+                request_serializer=job_service.CreateBatchPredictionJobRequest.serialize,
+                response_deserializer=gca_batch_prediction_job.BatchPredictionJob.deserialize,
+            )
+        return self._stubs["create_batch_prediction_job"]
+
+    @property
+    def get_batch_prediction_job(
+        self,
+    ) -> Callable[
+        [job_service.GetBatchPredictionJobRequest],
+        Awaitable[batch_prediction_job.BatchPredictionJob],
+    ]:
+        r"""Return a callable for the get batch prediction job method over gRPC.
+
+        Gets a BatchPredictionJob
+
+        Returns:
+            Callable[[~.GetBatchPredictionJobRequest],
+                    Awaitable[~.BatchPredictionJob]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "get_batch_prediction_job" not in self._stubs:
+            self._stubs["get_batch_prediction_job"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/GetBatchPredictionJob",
+                request_serializer=job_service.GetBatchPredictionJobRequest.serialize,
+                response_deserializer=batch_prediction_job.BatchPredictionJob.deserialize,
+            )
+        return self._stubs["get_batch_prediction_job"]
+
+    @property
+    def list_batch_prediction_jobs(
+        self,
+    ) -> Callable[
+        [job_service.ListBatchPredictionJobsRequest],
+        Awaitable[job_service.ListBatchPredictionJobsResponse],
+    ]:
+        r"""Return a callable for the list batch prediction jobs method over gRPC.
+
+        Lists BatchPredictionJobs in a Location.
+
+        Returns:
+            Callable[[~.ListBatchPredictionJobsRequest],
+                    Awaitable[~.ListBatchPredictionJobsResponse]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "list_batch_prediction_jobs" not in self._stubs:
+            self._stubs["list_batch_prediction_jobs"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/ListBatchPredictionJobs",
+                request_serializer=job_service.ListBatchPredictionJobsRequest.serialize,
+                response_deserializer=job_service.ListBatchPredictionJobsResponse.deserialize,
+            )
+        return self._stubs["list_batch_prediction_jobs"]
+
+    @property
+    def delete_batch_prediction_job(
+        self,
+    ) -> Callable[
+        [job_service.DeleteBatchPredictionJobRequest], Awaitable[operations.Operation]
+    ]:
+        r"""Return a callable for the delete batch prediction job method over gRPC.
+
+        Deletes a BatchPredictionJob. Can only be called on
+        jobs that already finished.
+
+        Returns:
+            Callable[[~.DeleteBatchPredictionJobRequest],
+                    Awaitable[~.Operation]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "delete_batch_prediction_job" not in self._stubs:
+            self._stubs["delete_batch_prediction_job"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/DeleteBatchPredictionJob",
+                request_serializer=job_service.DeleteBatchPredictionJobRequest.serialize,
+                response_deserializer=operations.Operation.FromString,
+            )
+        return self._stubs["delete_batch_prediction_job"]
+
+    @property
+    def cancel_batch_prediction_job(
+        self,
+    ) -> Callable[
+        [job_service.CancelBatchPredictionJobRequest], Awaitable[empty.Empty]
+    ]:
+        r"""Return a callable for the cancel batch prediction job method over gRPC.
+
+        Cancels a BatchPredictionJob.
+
+        Starts asynchronous cancellation on the BatchPredictionJob. The
+        server makes the best effort to cancel the job, but success is
+        not guaranteed. Clients can use
+        ``JobService.GetBatchPredictionJob``
+        or other methods to check whether the cancellation succeeded or
+        whether the job completed despite cancellation. On a successful
+        cancellation, the BatchPredictionJob is not deleted; instead its
+        ``BatchPredictionJob.state``
+        is set to ``CANCELLED``. Any files already output by the job
+        are not deleted.
+
+        Returns:
+            Callable[[~.CancelBatchPredictionJobRequest],
+                    Awaitable[~.Empty]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "cancel_batch_prediction_job" not in self._stubs:
+            self._stubs["cancel_batch_prediction_job"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.JobService/CancelBatchPredictionJob",
+                request_serializer=job_service.CancelBatchPredictionJobRequest.serialize,
+                response_deserializer=empty.Empty.FromString,
+            )
+        return self._stubs["cancel_batch_prediction_job"]
+
+
+__all__ = ("JobServiceGrpcAsyncIOTransport",)
diff --git a/google/cloud/aiplatform_v1/services/migration_service/__init__.py b/google/cloud/aiplatform_v1/services/migration_service/__init__.py
new file mode 100644
index 0000000000..1d6216d1f7
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/migration_service/__init__.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + +from .client import MigrationServiceClient +from .async_client import MigrationServiceAsyncClient + +__all__ = ( + "MigrationServiceClient", + "MigrationServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1/services/migration_service/async_client.py new file mode 100644 index 0000000000..fcb1d23da7 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/migration_service/async_client.py @@ -0,0 +1,367 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.migration_service import pagers +from google.cloud.aiplatform_v1.types import migratable_resource +from google.cloud.aiplatform_v1.types import migration_service + +from .transports.base import MigrationServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import MigrationServiceGrpcAsyncIOTransport +from .client import MigrationServiceClient + + +class MigrationServiceAsyncClient: + """A service that migrates resources from automl.googleapis.com, + datalabeling.googleapis.com and ml.googleapis.com to AI + Platform. 
+    """
+
+    _client: MigrationServiceClient
+
+    DEFAULT_ENDPOINT = MigrationServiceClient.DEFAULT_ENDPOINT
+    DEFAULT_MTLS_ENDPOINT = MigrationServiceClient.DEFAULT_MTLS_ENDPOINT
+
+    annotated_dataset_path = staticmethod(MigrationServiceClient.annotated_dataset_path)
+    parse_annotated_dataset_path = staticmethod(
+        MigrationServiceClient.parse_annotated_dataset_path
+    )
+    dataset_path = staticmethod(MigrationServiceClient.dataset_path)
+    parse_dataset_path = staticmethod(MigrationServiceClient.parse_dataset_path)
+    model_path = staticmethod(MigrationServiceClient.model_path)
+    parse_model_path = staticmethod(MigrationServiceClient.parse_model_path)
+    version_path = staticmethod(MigrationServiceClient.version_path)
+    parse_version_path = staticmethod(MigrationServiceClient.parse_version_path)
+
+    common_billing_account_path = staticmethod(
+        MigrationServiceClient.common_billing_account_path
+    )
+    parse_common_billing_account_path = staticmethod(
+        MigrationServiceClient.parse_common_billing_account_path
+    )
+
+    common_folder_path = staticmethod(MigrationServiceClient.common_folder_path)
+    parse_common_folder_path = staticmethod(
+        MigrationServiceClient.parse_common_folder_path
+    )
+
+    common_organization_path = staticmethod(
+        MigrationServiceClient.common_organization_path
+    )
+    parse_common_organization_path = staticmethod(
+        MigrationServiceClient.parse_common_organization_path
+    )
+
+    common_project_path = staticmethod(MigrationServiceClient.common_project_path)
+    parse_common_project_path = staticmethod(
+        MigrationServiceClient.parse_common_project_path
+    )
+
+    common_location_path = staticmethod(MigrationServiceClient.common_location_path)
+    parse_common_location_path = staticmethod(
+        MigrationServiceClient.parse_common_location_path
+    )
+
+    from_service_account_info = MigrationServiceClient.from_service_account_info
+    from_service_account_file = MigrationServiceClient.from_service_account_file
+    from_service_account_json = from_service_account_file
+
+    @property
+    def transport(self) -> MigrationServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            MigrationServiceTransport: The transport used by the client instance.
+        """
+        return self._client.transport
+
+    get_transport_class = functools.partial(
+        type(MigrationServiceClient).get_transport_class, type(MigrationServiceClient)
+    )
+
+    def __init__(
+        self,
+        *,
+        credentials: credentials.Credentials = None,
+        transport: Union[str, MigrationServiceTransport] = "grpc_asyncio",
+        client_options: ClientOptions = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the migration service client.
+
+        Args:
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            transport (Union[str, ~.MigrationServiceTransport]): The
+                transport to use. If set to None, a transport is chosen
+                automatically.
+ client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = MigrationServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def search_migratable_resources( + self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesAsyncPager: + r"""Searches all of the resources in + automl.googleapis.com, datalabeling.googleapis.com and + ml.googleapis.com that can be migrated to AI Platform's + given location. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.SearchMigratableResourcesRequest`): + The request object. Request message for + ``MigrationService.SearchMigratableResources``. + parent (:class:`str`): + Required. The location that the migratable resources + should be searched from. It's the AI Platform location + that the resources can be migrated to, not the + resources' original location. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.migration_service.pagers.SearchMigratableResourcesAsyncPager: + Response message for + ``MigrationService.SearchMigratableResources``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = migration_service.SearchMigratableResourcesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
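Because the response comes back wrapped in ``SearchMigratableResourcesAsyncPager``, iterating with ``async for`` pulls additional pages transparently. A sketch (the parent path is a placeholder):

    import asyncio

    from google.cloud.aiplatform_v1.services.migration_service import (
        MigrationServiceAsyncClient,
    )

    async def main() -> None:
        client = MigrationServiceAsyncClient()
        pager = await client.search_migratable_resources(
            parent="projects/my-project/locations/us-central1"
        )
        async for resource in pager:  # later pages are fetched lazily
            print(resource)

    asyncio.run(main())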
+
+        if parent is not None:
+            request.parent = parent
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = gapic_v1.method_async.wrap_method(
+            self._client._transport.search_migratable_resources,
+            default_timeout=None,
+            client_info=DEFAULT_CLIENT_INFO,
+        )
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+        )
+
+        # Send the request.
+        response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__aiter__` convenience method.
+        response = pagers.SearchMigratableResourcesAsyncPager(
+            method=rpc, request=request, response=response, metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    async def batch_migrate_resources(
+        self,
+        request: migration_service.BatchMigrateResourcesRequest = None,
+        *,
+        parent: str = None,
+        migrate_resource_requests: Sequence[
+            migration_service.MigrateResourceRequest
+        ] = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operation_async.AsyncOperation:
+        r"""Batch migrates resources from ml.googleapis.com,
+        automl.googleapis.com, and datalabeling.googleapis.com
+        to AI Platform (Unified).
+
+        Args:
+            request (:class:`google.cloud.aiplatform_v1.types.BatchMigrateResourcesRequest`):
+                The request object. Request message for
+                ``MigrationService.BatchMigrateResources``.
+            parent (:class:`str`):
+                Required. The location the migrated resources will live
+                in. Format:
+                ``projects/{project}/locations/{location}``
+
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            migrate_resource_requests (:class:`Sequence[google.cloud.aiplatform_v1.types.MigrateResourceRequest]`):
+                Required. The request messages
+                specifying the resources to migrate.
+                They must be in the same location as the
+                destination. Up to 50 resources can be
+                migrated in one batch.
+
+                This corresponds to the ``migrate_resource_requests`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.api_core.operation_async.AsyncOperation:
+                An object representing a long-running operation.
+
+                The result type for the operation will be
+                :class:`google.cloud.aiplatform_v1.types.BatchMigrateResourcesResponse`
+                Response message for
+                ``MigrationService.BatchMigrateResources``.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([parent, migrate_resource_requests])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        request = migration_service.BatchMigrateResourcesRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
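``batch_migrate_resources`` hands back an ``operation_async.AsyncOperation``; awaiting its ``result()`` yields the ``BatchMigrateResourcesResponse`` once migration finishes. A hedged sketch with a single request (the nested ``MigrateAutomlModelConfig`` field names are assumed from the v1 surface; resource names are placeholders):

    import asyncio

    from google.cloud.aiplatform_v1.services.migration_service import (
        MigrationServiceAsyncClient,
    )
    from google.cloud.aiplatform_v1.types import migration_service

    async def main() -> None:
        client = MigrationServiceAsyncClient()
        config = migration_service.MigrateResourceRequest.MigrateAutomlModelConfig(
            model="projects/123/locations/us-central1/models/456",
            model_display_name="migrated-model",
        )
        lro = await client.batch_migrate_resources(
            parent="projects/my-project/locations/us-central1",
            migrate_resource_requests=[
                migration_service.MigrateResourceRequest(
                    migrate_automl_model_config=config
                )
            ],
        )
        print(await lro.result())  # BatchMigrateResourcesResponse

    asyncio.run(main())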
+ + if parent is not None: + request.parent = parent + + if migrate_resource_requests: + request.migrate_resource_requests.extend(migrate_resource_requests) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.batch_migrate_resources, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + migration_service.BatchMigrateResourcesResponse, + metadata_type=migration_service.BatchMigrateResourcesOperationMetadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("MigrationServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py new file mode 100644 index 0000000000..3ed18e0fa8 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -0,0 +1,654 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from collections import OrderedDict
+from distutils import util
+import os
+import re
+from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union
+import pkg_resources
+
+from google.api_core import client_options as client_options_lib  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport import mtls  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+from google.auth.exceptions import MutualTLSChannelError  # type: ignore
+from google.oauth2 import service_account  # type: ignore
+
+from google.api_core import operation  # type: ignore
+from google.api_core import operation_async  # type: ignore
+from google.cloud.aiplatform_v1.services.migration_service import pagers
+from google.cloud.aiplatform_v1.types import migratable_resource
+from google.cloud.aiplatform_v1.types import migration_service
+
+from .transports.base import MigrationServiceTransport, DEFAULT_CLIENT_INFO
+from .transports.grpc import MigrationServiceGrpcTransport
+from .transports.grpc_asyncio import MigrationServiceGrpcAsyncIOTransport
+
+
+class MigrationServiceClientMeta(type):
+    """Metaclass for the MigrationService client.
+
+    This provides class-level methods for building and retrieving
+    support objects (e.g. transport) without polluting the client instance
+    objects.
+    """
+
+    _transport_registry = (
+        OrderedDict()
+    )  # type: Dict[str, Type[MigrationServiceTransport]]
+    _transport_registry["grpc"] = MigrationServiceGrpcTransport
+    _transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport
+
+    def get_transport_class(cls, label: str = None,) -> Type[MigrationServiceTransport]:
+        """Return an appropriate transport class.
+
+        Args:
+            label: The name of the desired transport. If none is
+                provided, then the first transport in the registry is used.
+
+        Returns:
+            The transport class to use.
+        """
+        # If a specific transport is requested, return that one.
+        if label:
+            return cls._transport_registry[label]
+
+        # No transport is requested; return the default (that is, the first one
+        # in the dictionary).
+        return next(iter(cls._transport_registry.values()))
+
+
+class MigrationServiceClient(metaclass=MigrationServiceClientMeta):
+    """A service that migrates resources from automl.googleapis.com,
+    datalabeling.googleapis.com and ml.googleapis.com to AI
+    Platform.
+    """
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Convert api endpoint to mTLS endpoint.
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+ ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + MigrationServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + MigrationServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> MigrationServiceTransport: + """Return the transport used by the client instance. + + Returns: + MigrationServiceTransport: The transport used by the client instance. 
+        """
+        return self._transport
+
+    @staticmethod
+    def annotated_dataset_path(
+        project: str, dataset: str, annotated_dataset: str,
+    ) -> str:
+        """Return a fully-qualified annotated_dataset string."""
+        return "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(
+            project=project, dataset=dataset, annotated_dataset=annotated_dataset,
+        )
+
+    @staticmethod
+    def parse_annotated_dataset_path(path: str) -> Dict[str, str]:
+        """Parse an annotated_dataset path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/datasets/(?P<dataset>.+?)/annotatedDatasets/(?P<annotated_dataset>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def dataset_path(project: str, location: str, dataset: str,) -> str:
+        """Return a fully-qualified dataset string."""
+        return "projects/{project}/locations/{location}/datasets/{dataset}".format(
+            project=project, location=location, dataset=dataset,
+        )
+
+    @staticmethod
+    def parse_dataset_path(path: str) -> Dict[str, str]:
+        """Parse a dataset path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/datasets/(?P<dataset>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def dataset_path(project: str, dataset: str,) -> str:
+        """Return a fully-qualified dataset string."""
+        return "projects/{project}/datasets/{dataset}".format(
+            project=project, dataset=dataset,
+        )
+
+    @staticmethod
+    def parse_dataset_path(path: str) -> Dict[str, str]:
+        """Parse a dataset path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/datasets/(?P<dataset>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def model_path(project: str, location: str, model: str,) -> str:
+        """Return a fully-qualified model string."""
+        return "projects/{project}/locations/{location}/models/{model}".format(
+            project=project, location=location, model=model,
+        )
+
+    @staticmethod
+    def parse_model_path(path: str) -> Dict[str, str]:
+        """Parse a model path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def version_path(project: str, model: str, version: str,) -> str:
+        """Return a fully-qualified version string."""
+        return "projects/{project}/models/{model}/versions/{version}".format(
+            project=project, model=model, version=version,
+        )
+
+    @staticmethod
+    def parse_version_path(path: str) -> Dict[str, str]:
+        """Parse a version path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/models/(?P<model>.+?)/versions/(?P<version>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_billing_account_path(billing_account: str,) -> str:
+        """Return a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(
+            billing_account=billing_account,
+        )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str, str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(folder: str,) -> str:
+        """Return a fully-qualified folder string."""
+        return "folders/{folder}".format(folder=folder,)
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str, str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_organization_path(organization: str,) -> str:
+        """Return a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization,)
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str, str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str,) -> str:
+        """Return a fully-qualified project string."""
+        return "projects/{project}".format(project=project,)
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str, str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str,) -> str:
+        """Return a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(
+            project=project, location=location,
+        )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str, str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    def __init__(
+        self,
+        *,
+        credentials: Optional[credentials.Credentials] = None,
+        transport: Union[str, MigrationServiceTransport, None] = None,
+        client_options: Optional[client_options_lib.ClientOptions] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the migration service client.
+
+        Args:
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            transport (Union[str, MigrationServiceTransport]): The
+                transport to use. If set to None, a transport is chosen
+                automatically.
+            client_options (google.api_core.client_options.ClientOptions): Custom options for the
+                client. It won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client.
GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, MigrationServiceTransport): + # transport is a MigrationServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." 
+ ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def search_migratable_resources( + self, + request: migration_service.SearchMigratableResourcesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.SearchMigratableResourcesPager: + r"""Searches all of the resources in + automl.googleapis.com, datalabeling.googleapis.com and + ml.googleapis.com that can be migrated to AI Platform's + given location. + + Args: + request (google.cloud.aiplatform_v1.types.SearchMigratableResourcesRequest): + The request object. Request message for + ``MigrationService.SearchMigratableResources``. + parent (str): + Required. The location that the migratable resources + should be searched from. It's the AI Platform location + that the resources can be migrated to, not the + resources' original location. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.migration_service.pagers.SearchMigratableResourcesPager: + Response message for + ``MigrationService.SearchMigratableResources``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a migration_service.SearchMigratableResourcesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, migration_service.SearchMigratableResourcesRequest): + request = migration_service.SearchMigratableResourcesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.search_migratable_resources + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. 
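For reference, a small sketch of what the pager wrapping below gives callers (placeholder resource name); items can be iterated directly, or whole pages via the ``pages`` property:

    from google.cloud import aiplatform_v1

    client = aiplatform_v1.MigrationServiceClient()
    pager = client.search_migratable_resources(
        parent="projects/my-project/locations/us-central1"  # placeholder
    )

    # Item-level iteration; additional pages are requested transparently.
    for resource in pager:
        print(resource)

    # Or page-level iteration (use one style per pager instance):
    # for page in pager.pages:
    #     print(len(page.migratable_resources))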
+ response = pagers.SearchMigratableResourcesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def batch_migrate_resources( + self, + request: migration_service.BatchMigrateResourcesRequest = None, + *, + parent: str = None, + migrate_resource_requests: Sequence[ + migration_service.MigrateResourceRequest + ] = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation.Operation: + r"""Batch migrates resources from ml.googleapis.com, + automl.googleapis.com, and datalabeling.googleapis.com + to AI Platform (Unified). + + Args: + request (google.cloud.aiplatform_v1.types.BatchMigrateResourcesRequest): + The request object. Request message for + ``MigrationService.BatchMigrateResources``. + parent (str): + Required. The location of the migrated resource will + live in. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + migrate_resource_requests (Sequence[google.cloud.aiplatform_v1.types.MigrateResourceRequest]): + Required. The request messages + specifying the resources to migrate. + They must be in the same location as the + destination. Up to 50 resources can be + migrated in one batch. + + This corresponds to the ``migrate_resource_requests`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.BatchMigrateResourcesResponse` + Response message for + ``MigrationService.BatchMigrateResources``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, migrate_resource_requests]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a migration_service.BatchMigrateResourcesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, migration_service.BatchMigrateResourcesRequest): + request = migration_service.BatchMigrateResourcesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + if migrate_resource_requests: + request.migrate_resource_requests.extend(migrate_resource_requests) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.batch_migrate_resources] + + # Certain fields should be provided within the metadata header; + # add these here. 
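For context on the routing-header step below: ``to_grpc_metadata`` collapses the listed request fields into the standard ``x-goog-request-params`` metadata entry. A simplified illustration of the shape it produces (the real helper also URL-encodes values):

    def to_request_params(fields):
        # e.g. [("parent", "projects/p/locations/l")] becomes
        # ("x-goog-request-params", "parent=projects/p/locations/l")
        return ("x-goog-request-params", "&".join(f"{k}={v}" for k, v in fields))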
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + migration_service.BatchMigrateResourcesResponse, + metadata_type=migration_service.BatchMigrateResourcesOperationMetadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("MigrationServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1/services/migration_service/pagers.py new file mode 100644 index 0000000000..b7d9f4ae44 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/migration_service/pagers.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple + +from google.cloud.aiplatform_v1.types import migratable_resource +from google.cloud.aiplatform_v1.types import migration_service + + +class SearchMigratableResourcesPager: + """A pager for iterating through ``search_migratable_resources`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.SearchMigratableResourcesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``migratable_resources`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``SearchMigratableResources`` requests and continue to iterate + through the ``migratable_resources`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.SearchMigratableResourcesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., migration_service.SearchMigratableResourcesResponse], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.SearchMigratableResourcesRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.SearchMigratableResourcesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = migration_service.SearchMigratableResourcesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[migration_service.SearchMigratableResourcesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[migratable_resource.MigratableResource]: + for page in self.pages: + yield from page.migratable_resources + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class SearchMigratableResourcesAsyncPager: + """A pager for iterating through ``search_migratable_resources`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.SearchMigratableResourcesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``migratable_resources`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``SearchMigratableResources`` requests and continue to iterate + through the ``migratable_resources`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.SearchMigratableResourcesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[migration_service.SearchMigratableResourcesResponse] + ], + request: migration_service.SearchMigratableResourcesRequest, + response: migration_service.SearchMigratableResourcesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.SearchMigratableResourcesRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.SearchMigratableResourcesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = migration_service.SearchMigratableResourcesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[migration_service.SearchMigratableResourcesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[migratable_resource.MigratableResource]: + async def async_generator(): + async for page in self.pages: + for response in page.migratable_resources: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/migration_service/transports/__init__.py new file mode 100644 index 0000000000..38c72756f6 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import MigrationServiceTransport +from .grpc import MigrationServiceGrpcTransport +from .grpc_asyncio import MigrationServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[MigrationServiceTransport]] +_transport_registry["grpc"] = MigrationServiceGrpcTransport +_transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport + +__all__ = ( + "MigrationServiceTransport", + "MigrationServiceGrpcTransport", + "MigrationServiceGrpcAsyncIOTransport", +) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1/services/migration_service/transports/base.py new file mode 100644 index 0000000000..da4cabae63 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/base.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import abc
+import typing
+import pkg_resources
+
+from google import auth  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google.auth import credentials  # type: ignore
+
+from google.cloud.aiplatform_v1.types import migration_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            "google-cloud-aiplatform",
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+class MigrationServiceTransport(abc.ABC):
+    """Abstract transport class for MigrationService."""
+
+    AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: typing.Optional[str] = None,
+        scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+        quota_project_id: typing.Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        **kwargs,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+        """
+        # Save the hostname. Default to port 443 (HTTPS) if none is specified.
+        if ":" not in host:
+            host += ":443"
+        self._host = host
+
+        # If no credentials are provided, then determine the appropriate
+        # defaults.
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs(
+                "'credentials_file' and 'credentials' are mutually exclusive"
+            )
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file, scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(
+                scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        # Save the credentials.
+        self._credentials = credentials
+
+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
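Before the table of wrapped methods below, note that the defaults registered there (``default_timeout=None``) can still be overridden on any individual call; a hedged sketch with placeholder names and values:

    from google.api_core import retry as retries
    from google.cloud import aiplatform_v1

    client = aiplatform_v1.MigrationServiceClient()
    pager = client.search_migratable_resources(
        parent="projects/my-project/locations/us-central1",  # placeholder
        retry=retries.Retry(initial=0.1, maximum=60.0, multiplier=1.3, deadline=600.0),
        timeout=600.0,
    )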
+ self._wrapped_methods = { + self.search_migratable_resources: gapic_v1.method.wrap_method( + self.search_migratable_resources, + default_timeout=None, + client_info=client_info, + ), + self.batch_migrate_resources: gapic_v1.method.wrap_method( + self.batch_migrate_resources, + default_timeout=None, + client_info=client_info, + ), + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def search_migratable_resources( + self, + ) -> typing.Callable[ + [migration_service.SearchMigratableResourcesRequest], + typing.Union[ + migration_service.SearchMigratableResourcesResponse, + typing.Awaitable[migration_service.SearchMigratableResourcesResponse], + ], + ]: + raise NotImplementedError() + + @property + def batch_migrate_resources( + self, + ) -> typing.Callable[ + [migration_service.BatchMigrateResourcesRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + +__all__ = ("MigrationServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py new file mode 100644 index 0000000000..820a38a028 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc.py @@ -0,0 +1,333 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1.types import migration_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import MigrationServiceTransport, DEFAULT_CLIENT_INFO + + +class MigrationServiceGrpcTransport(MigrationServiceTransport): + """gRPC backend transport for MigrationService. + + A service that migrates resources from automl.googleapis.com, + datalabeling.googleapis.com and ml.googleapis.com to AI + Platform. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. 
+    """
+
+    _stubs: Dict[str, Callable]
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Sequence[str] = None,
+        channel: grpc.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+            channel (Optional[grpc.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure mutual TLS channel. It is
+                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        self._ssl_channel_credentials = ssl_channel_credentials
+
+        if api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+        if client_cert_source:
+            warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+            self._ssl_channel_credentials = None
+        elif api_mtls_endpoint:
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            self._ssl_channel_credentials = ssl_credentials
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            if client_cert_source_for_mtls and not ssl_channel_credentials:
+                cert, key = client_cert_source_for_mtls()
+                self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=self._ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+
+        self._stubs = {}  # type: Dict[str, Callable]
+        self._operations_client = None
+
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> grpc.Channel:
+        """Create and return a gRPC channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            grpc.Channel: A gRPC channel object.
+ + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service. + """ + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) + + # Return the client from cache. + return self._operations_client + + @property + def search_migratable_resources( + self, + ) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + migration_service.SearchMigratableResourcesResponse, + ]: + r"""Return a callable for the search migratable resources method over gRPC. + + Searches all of the resources in + automl.googleapis.com, datalabeling.googleapis.com and + ml.googleapis.com that can be migrated to AI Platform's + given location. + + Returns: + Callable[[~.SearchMigratableResourcesRequest], + ~.SearchMigratableResourcesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "search_migratable_resources" not in self._stubs: + self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.MigrationService/SearchMigratableResources", + request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, + response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, + ) + return self._stubs["search_migratable_resources"] + + @property + def batch_migrate_resources( + self, + ) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], operations.Operation + ]: + r"""Return a callable for the batch migrate resources method over gRPC. + + Batch migrates resources from ml.googleapis.com, + automl.googleapis.com, and datalabeling.googleapis.com + to AI Platform (Unified). + + Returns: + Callable[[~.BatchMigrateResourcesRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
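Before the matching ``batch_migrate_resources`` stub below, a short sketch of how ``create_channel`` composes with this transport and the client (assuming application default credentials are available in the environment):

    from google.cloud.aiplatform_v1.services.migration_service import (
        MigrationServiceClient,
    )
    from google.cloud.aiplatform_v1.services.migration_service.transports.grpc import (
        MigrationServiceGrpcTransport,
    )

    # create_channel falls back to application default credentials.
    channel = MigrationServiceGrpcTransport.create_channel(
        "aiplatform.googleapis.com",
    )
    transport = MigrationServiceGrpcTransport(channel=channel)
    client = MigrationServiceClient(transport=transport)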
+        if "batch_migrate_resources" not in self._stubs:
+            self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.MigrationService/BatchMigrateResources",
+                request_serializer=migration_service.BatchMigrateResourcesRequest.serialize,
+                response_deserializer=operations.Operation.FromString,
+            )
+        return self._stubs["batch_migrate_resources"]
+
+
+__all__ = ("MigrationServiceGrpcTransport",)
diff --git a/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py
new file mode 100644
index 0000000000..dbdddf31e5
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/migration_service/transports/grpc_asyncio.py
@@ -0,0 +1,340 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1.types import migration_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+from .base import MigrationServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import MigrationServiceGrpcTransport
+
+
+class MigrationServiceGrpcAsyncIOTransport(MigrationServiceTransport):
+    """gRPC AsyncIO backend transport for MigrationService.
+
+    A service that migrates resources from automl.googleapis.com,
+    datalabeling.googleapis.com and ml.googleapis.com to AI
+    Platform.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers_async.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        channel: aio.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id=None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            channel (Optional[aio.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure mutual TLS channel. It is
+                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+ Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + self._operations_client = None + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. 
+ if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def search_migratable_resources( + self, + ) -> Callable[ + [migration_service.SearchMigratableResourcesRequest], + Awaitable[migration_service.SearchMigratableResourcesResponse], + ]: + r"""Return a callable for the search migratable resources method over gRPC. + + Searches all of the resources in + automl.googleapis.com, datalabeling.googleapis.com and + ml.googleapis.com that can be migrated to AI Platform's + given location. + + Returns: + Callable[[~.SearchMigratableResourcesRequest], + Awaitable[~.SearchMigratableResourcesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "search_migratable_resources" not in self._stubs: + self._stubs["search_migratable_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.MigrationService/SearchMigratableResources", + request_serializer=migration_service.SearchMigratableResourcesRequest.serialize, + response_deserializer=migration_service.SearchMigratableResourcesResponse.deserialize, + ) + return self._stubs["search_migratable_resources"] + + @property + def batch_migrate_resources( + self, + ) -> Callable[ + [migration_service.BatchMigrateResourcesRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the batch migrate resources method over gRPC. + + Batch migrates resources from ml.googleapis.com, + automl.googleapis.com, and datalabeling.googleapis.com + to AI Platform (Unified). + + Returns: + Callable[[~.BatchMigrateResourcesRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_migrate_resources" not in self._stubs: + self._stubs["batch_migrate_resources"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.MigrationService/BatchMigrateResources", + request_serializer=migration_service.BatchMigrateResourcesRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["batch_migrate_resources"] + + +__all__ = ("MigrationServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/model_service/__init__.py b/google/cloud/aiplatform_v1/services/model_service/__init__.py new file mode 100644 index 0000000000..b39295ebfe --- /dev/null +++ b/google/cloud/aiplatform_v1/services/model_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
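Editor's note: the RPC properties above all share one pattern: the stub is created on first access and cached in ``self._stubs`` so every later call to the same method reuses a single callable. A generic sketch of that lazy-caching idea (not the library's code; ``channel`` stands in for any gRPC channel object):

```python
class LazyStubCache:
    """Build each per-method gRPC callable once, then reuse it."""

    def __init__(self, channel):
        self._channel = channel
        self._stubs = {}

    def unary_unary(self, method_path, request_serializer, response_deserializer):
        # First access creates the stub; later accesses return the cached one.
        if method_path not in self._stubs:
            self._stubs[method_path] = self._channel.unary_unary(
                method_path,
                request_serializer=request_serializer,
                response_deserializer=response_deserializer,
            )
        return self._stubs[method_path]
```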
+# + +from .client import ModelServiceClient +from .async_client import ModelServiceAsyncClient + +__all__ = ( + "ModelServiceClient", + "ModelServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1/services/model_service/async_client.py b/google/cloud/aiplatform_v1/services/model_service/async_client.py new file mode 100644 index 0000000000..123b922019 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/model_service/async_client.py @@ -0,0 +1,1031 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.model_service import pagers +from google.cloud.aiplatform_v1.types import deployed_model_ref +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import model +from google.cloud.aiplatform_v1.types import model as gca_model +from google.cloud.aiplatform_v1.types import model_evaluation +from google.cloud.aiplatform_v1.types import model_evaluation_slice +from google.cloud.aiplatform_v1.types import model_service +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport +from .client import ModelServiceClient + + +class ModelServiceAsyncClient: + """A service for managing AI Platform's machine learning Models.""" + + _client: ModelServiceClient + + DEFAULT_ENDPOINT = ModelServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = ModelServiceClient.DEFAULT_MTLS_ENDPOINT + + endpoint_path = staticmethod(ModelServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(ModelServiceClient.parse_endpoint_path) + model_path = staticmethod(ModelServiceClient.model_path) + parse_model_path = staticmethod(ModelServiceClient.parse_model_path) + model_evaluation_path = staticmethod(ModelServiceClient.model_evaluation_path) + parse_model_evaluation_path = staticmethod( + ModelServiceClient.parse_model_evaluation_path + ) + model_evaluation_slice_path = staticmethod( + 
ModelServiceClient.model_evaluation_slice_path + ) + parse_model_evaluation_slice_path = staticmethod( + ModelServiceClient.parse_model_evaluation_slice_path + ) + training_pipeline_path = staticmethod(ModelServiceClient.training_pipeline_path) + parse_training_pipeline_path = staticmethod( + ModelServiceClient.parse_training_pipeline_path + ) + + common_billing_account_path = staticmethod( + ModelServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + ModelServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(ModelServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path) + + common_organization_path = staticmethod(ModelServiceClient.common_organization_path) + parse_common_organization_path = staticmethod( + ModelServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(ModelServiceClient.common_project_path) + parse_common_project_path = staticmethod( + ModelServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod(ModelServiceClient.common_location_path) + parse_common_location_path = staticmethod( + ModelServiceClient.parse_common_location_path + ) + + from_service_account_info = ModelServiceClient.from_service_account_info + from_service_account_file = ModelServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> ModelServiceTransport: + """Return the transport used by the client instance. + + Returns: + ModelServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(ModelServiceClient).get_transport_class, type(ModelServiceClient) + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, ModelServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the model service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.ModelServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. 
+ + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = ModelServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def upload_model( + self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Uploads a Model artifact into AI Platform. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.UploadModelRequest`): + The request object. Request message for + ``ModelService.UploadModel``. + parent (:class:`str`): + Required. The resource name of the Location into which + to upload the Model. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + model (:class:`google.cloud.aiplatform_v1.types.Model`): + Required. The Model to create. + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.UploadModelResponse` + Response message of + ``ModelService.UploadModel`` + operation. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, model]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.UploadModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if model is not None: + request.model = model + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.upload_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + model_service.UploadModelResponse, + metadata_type=model_service.UploadModelOperationMetadata, + ) + + # Done; return the response. 
+ return response + + async def get_model( + self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: + r"""Gets a Model. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.GetModelRequest`): + The request object. Request message for + ``ModelService.GetModel``. + name (:class:`str`): + Required. The name of the Model resource. Format: + ``projects/{project}/locations/{location}/models/{model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Model: + A trained machine learning Model. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.GetModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_models( + self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsAsyncPager: + r"""Lists Models in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListModelsRequest`): + The request object. Request message for + ``ModelService.ListModels``. + parent (:class:`str`): + Required. The resource name of the Location to list the + Models from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.model_service.pagers.ListModelsAsyncPager: + Response message for + ``ModelService.ListModels`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. 
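Editor's note: a hedged usage sketch of the async surface defined so far; the project and model IDs are hypothetical, and ambient credentials are assumed:

```python
import asyncio

from google.cloud.aiplatform_v1.services.model_service import ModelServiceAsyncClient


async def main():
    client = ModelServiceAsyncClient()
    parent = "projects/my-project/locations/us-central1"  # hypothetical

    # Pass either a request object or flattened fields, never both.
    model = await client.get_model(name=f"{parent}/models/1234")
    print(model.display_name)

    # The async pager fetches additional pages transparently while iterating.
    pager = await client.list_models(parent=parent)
    async for m in pager:
        print(m.name)


asyncio.run(main())
```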
+ # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.ListModelsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_models, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListModelsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_model( + self, + request: model_service.UpdateModelRequest = None, + *, + model: gca_model.Model = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> gca_model.Model: + r"""Updates a Model. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.UpdateModelRequest`): + The request object. Request message for + ``ModelService.UpdateModel``. + model (:class:`google.cloud.aiplatform_v1.types.Model`): + Required. The Model which replaces + the resource on the server. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The update mask applies to the resource. For + the ``FieldMask`` definition, see + `FieldMask `__. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Model: + A trained machine learning Model. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.UpdateModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if model is not None: + request.model = model + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("model.name", request.model.name),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def delete_model( + self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a Model. + Note: Model can only be deleted if there are no + DeployedModels created from it. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.DeleteModelRequest`): + The request object. Request message for + ``ModelService.DeleteModel``. + name (:class:`str`): + Required. The name of the Model resource to be deleted. + Format: + ``projects/{project}/locations/{location}/models/{model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.DeleteModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. 
+ return response + + async def export_model( + self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Exports a trained, exportable, Model to a location specified by + the user. A Model is considered to be exportable if it has at + least one [supported export + format][google.cloud.aiplatform.v1.Model.supported_export_formats]. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ExportModelRequest`): + The request object. Request message for + ``ModelService.ExportModel``. + name (:class:`str`): + Required. The resource name of the Model to export. + Format: + ``projects/{project}/locations/{location}/models/{model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + output_config (:class:`google.cloud.aiplatform_v1.types.ExportModelRequest.OutputConfig`): + Required. The desired output location + and configuration. + + This corresponds to the ``output_config`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.ExportModelResponse` + Response message of + ``ModelService.ExportModel`` + operation. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name, output_config]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.ExportModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + if output_config is not None: + request.output_config = output_config + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.export_model, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + model_service.ExportModelResponse, + metadata_type=model_service.ExportModelOperationMetadata, + ) + + # Done; return the response. 
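Editor's note: the long-running methods above resolve through ``operation_async``; awaiting ``result()`` suspends until the server-side operation finishes. A sketch for ``export_model``, assuming a custom-trained model and a hypothetical GCS destination:

```python
from google.cloud.aiplatform_v1.services.model_service import ModelServiceAsyncClient
from google.cloud.aiplatform_v1.types import io, model_service


async def export_model(client: ModelServiceAsyncClient, name: str):
    # output_config is required; the format id and bucket are hypothetical.
    output_config = model_service.ExportModelRequest.OutputConfig(
        export_format_id="custom-trained",
        artifact_destination=io.GcsDestination(
            output_uri_prefix="gs://my-bucket/exports/"
        ),
    )
    lro = await client.export_model(name=name, output_config=output_config)
    return await lro.result()  # resolves to an ExportModelResponse
```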
+ return response + + async def get_model_evaluation( + self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: + r"""Gets a ModelEvaluation. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.GetModelEvaluationRequest`): + The request object. Request message for + ``ModelService.GetModelEvaluation``. + name (:class:`str`): + Required. The name of the ModelEvaluation resource. + Format: + + ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.ModelEvaluation: + A collection of metrics calculated by + comparing Model's predictions on all of + the test data against annotations from + the test data. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.GetModelEvaluationRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_model_evaluation, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_model_evaluations( + self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsAsyncPager: + r"""Lists ModelEvaluations in a Model. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListModelEvaluationsRequest`): + The request object. Request message for + ``ModelService.ListModelEvaluations``. + parent (:class:`str`): + Required. The resource name of the Model to list the + ModelEvaluations from. Format: + ``projects/{project}/locations/{location}/models/{model}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + google.cloud.aiplatform_v1.services.model_service.pagers.ListModelEvaluationsAsyncPager: + Response message for + ``ModelService.ListModelEvaluations``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.ListModelEvaluationsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_model_evaluations, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListModelEvaluationsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_model_evaluation_slice( + self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: + r"""Gets a ModelEvaluationSlice. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.GetModelEvaluationSliceRequest`): + The request object. Request message for + ``ModelService.GetModelEvaluationSlice``. + name (:class:`str`): + Required. The name of the ModelEvaluationSlice resource. + Format: + + ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.ModelEvaluationSlice: + A collection of metrics calculated by + comparing Model's predictions on a slice + of the test data against ground truth + annotations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) + + request = model_service.GetModelEvaluationSliceRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_model_evaluation_slice, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_model_evaluation_slices( + self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesAsyncPager: + r"""Lists ModelEvaluationSlices in a ModelEvaluation. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListModelEvaluationSlicesRequest`): + The request object. Request message for + ``ModelService.ListModelEvaluationSlices``. + parent (:class:`str`): + Required. The resource name of the ModelEvaluation to + list the ModelEvaluationSlices from. Format: + + ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.model_service.pagers.ListModelEvaluationSlicesAsyncPager: + Response message for + ``ModelService.ListModelEvaluationSlices``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = model_service.ListModelEvaluationSlicesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_model_evaluation_slices, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. 
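Editor's note: the evaluation listing methods nest naturally, since slices hang off an evaluation and evaluations hang off a model. A hedged sketch with hypothetical resource names:

```python
from google.cloud.aiplatform_v1.services.model_service import ModelServiceAsyncClient


async def show_evaluations(client: ModelServiceAsyncClient, model_name: str):
    evaluations = await client.list_model_evaluations(parent=model_name)
    async for evaluation in evaluations:
        print(evaluation.name)
        slices = await client.list_model_evaluation_slices(parent=evaluation.name)
        async for evaluation_slice in slices:
            print(evaluation_slice.name)
```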
+ response = pagers.ListModelEvaluationSlicesAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("ModelServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/model_service/client.py b/google/cloud/aiplatform_v1/services/model_service/client.py new file mode 100644 index 0000000000..fa75f3c22b --- /dev/null +++ b/google/cloud/aiplatform_v1/services/model_service/client.py @@ -0,0 +1,1307 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.model_service import pagers +from google.cloud.aiplatform_v1.types import deployed_model_ref +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import model +from google.cloud.aiplatform_v1.types import model as gca_model +from google.cloud.aiplatform_v1.types import model_evaluation +from google.cloud.aiplatform_v1.types import model_evaluation_slice +from google.cloud.aiplatform_v1.types import model_service +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + +from .transports.base import ModelServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import ModelServiceGrpcTransport +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport + + +class ModelServiceClientMeta(type): + """Metaclass for the ModelService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. 
+ """ + + _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] + _transport_registry["grpc"] = ModelServiceGrpcTransport + _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport + + def get_transport_class(cls, label: str = None,) -> Type[ModelServiceTransport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class ModelServiceClient(metaclass=ModelServiceClientMeta): + """A service for managing AI Platform's machine learning Models.""" + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> ModelServiceTransport: + """Return the transport used by the client instance. + + Returns: + ModelServiceTransport: The transport used by the client instance. 
+ """ + return self._transport + + @staticmethod + def endpoint_path(project: str, location: str, endpoint: str,) -> str: + """Return a fully-qualified endpoint string.""" + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) + + @staticmethod + def parse_endpoint_path(path: str) -> Dict[str, str]: + """Parse a endpoint path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def model_path(project: str, location: str, model: str,) -> str: + """Return a fully-qualified model string.""" + return "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str, str]: + """Parse a model path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def model_evaluation_path( + project: str, location: str, model: str, evaluation: str, + ) -> str: + """Return a fully-qualified model_evaluation string.""" + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format( + project=project, location=location, model=model, evaluation=evaluation, + ) + + @staticmethod + def parse_model_evaluation_path(path: str) -> Dict[str, str]: + """Parse a model_evaluation path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def model_evaluation_slice_path( + project: str, location: str, model: str, evaluation: str, slice: str, + ) -> str: + """Return a fully-qualified model_evaluation_slice string.""" + return "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format( + project=project, + location=location, + model=model, + evaluation=evaluation, + slice=slice, + ) + + @staticmethod + def parse_model_evaluation_slice_path(path: str) -> Dict[str, str]: + """Parse a model_evaluation_slice path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/models/(?P.+?)/evaluations/(?P.+?)/slices/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def training_pipeline_path( + project: str, location: str, training_pipeline: str, + ) -> str: + """Return a fully-qualified training_pipeline string.""" + return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) + + @staticmethod + def parse_training_pipeline_path(path: str) -> Dict[str, str]: + """Parse a training_pipeline path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/trainingPipelines/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if 
m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, ModelServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the model service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ModelServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. 
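Editor's note: the resource-path helpers above are pure string utilities, so they can be exercised without any network access; values here are hypothetical:

```python
from google.cloud.aiplatform_v1.services.model_service import ModelServiceClient

name = ModelServiceClient.model_path(
    project="my-project", location="us-central1", model="1234"
)
assert name == "projects/my-project/locations/us-central1/models/1234"

parts = ModelServiceClient.parse_model_path(name)
assert parts == {"project": "my-project", "location": "us-central1", "model": "1234"}
```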
+ client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, ModelServiceTransport): + # transport is a ModelServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." + ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def upload_model( + self, + request: model_service.UploadModelRequest = None, + *, + parent: str = None, + model: gca_model.Model = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Uploads a Model artifact into AI Platform. + + Args: + request (google.cloud.aiplatform_v1.types.UploadModelRequest): + The request object. Request message for + ``ModelService.UploadModel``. + parent (str): + Required. The resource name of the Location into which + to upload the Model. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + model (google.cloud.aiplatform_v1.types.Model): + Required. The Model to create. 
+ This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.UploadModelResponse` + Response message of + ``ModelService.UploadModel`` + operation. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, model]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.UploadModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.UploadModelRequest): + request = model_service.UploadModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if model is not None: + request.model = model + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.upload_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + model_service.UploadModelResponse, + metadata_type=model_service.UploadModelOperationMetadata, + ) + + # Done; return the response. + return response + + def get_model( + self, + request: model_service.GetModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model.Model: + r"""Gets a Model. + + Args: + request (google.cloud.aiplatform_v1.types.GetModelRequest): + The request object. Request message for + ``ModelService.GetModel``. + name (str): + Required. The name of the Model resource. Format: + ``projects/{project}/locations/{location}/models/{model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.Model: + A trained machine learning Model. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+ has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.GetModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.GetModelRequest): + request = model_service.GetModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_models( + self, + request: model_service.ListModelsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelsPager: + r"""Lists Models in a Location. + + Args: + request (google.cloud.aiplatform_v1.types.ListModelsRequest): + The request object. Request message for + ``ModelService.ListModels``. + parent (str): + Required. The resource name of the Location to list the + Models from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.model_service.pagers.ListModelsPager: + Response message for + ``ModelService.ListModels`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ListModelsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ListModelsRequest): + request = model_service.ListModelsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_models] + + # Certain fields should be provided within the metadata header; + # add these here. 
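+        # The tuple built below is sent as the ``x-goog-request-params`` request
+        # header; Google's frontends use it to route the call to the correct
+        # regional backend, keyed here on the ``parent`` resource name.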
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)),
+        )
+
+        # Send the request.
+        response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # This method is paged; wrap the response in a pager, which provides
+        # an `__iter__` convenience method.
+        response = pagers.ListModelsPager(
+            method=rpc, request=request, response=response, metadata=metadata,
+        )
+
+        # Done; return the response.
+        return response
+
+    def update_model(
+        self,
+        request: model_service.UpdateModelRequest = None,
+        *,
+        model: gca_model.Model = None,
+        update_mask: field_mask.FieldMask = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> gca_model.Model:
+        r"""Updates a Model.
+
+        Args:
+            request (google.cloud.aiplatform_v1.types.UpdateModelRequest):
+                The request object. Request message for
+                ``ModelService.UpdateModel``.
+            model (google.cloud.aiplatform_v1.types.Model):
+                Required. The Model which replaces
+                the resource on the server.
+
+                This corresponds to the ``model`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            update_mask (google.protobuf.field_mask_pb2.FieldMask):
+                Required. The update mask applies to the resource. For
+                the ``FieldMask`` definition, see
+                `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__.
+
+                This corresponds to the ``update_mask`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.cloud.aiplatform_v1.types.Model:
+                A trained machine learning Model.
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
+        has_flattened_params = any([model, update_mask])
+        if request is not None and has_flattened_params:
+            raise ValueError(
+                "If the `request` argument is set, then none of "
+                "the individual field arguments should be set."
+            )
+
+        # Minor optimization to avoid making a copy if the user passes
+        # in a model_service.UpdateModelRequest.
+        # There's no risk of modifying the input as we've already verified
+        # there are no flattened fields.
+        if not isinstance(request, model_service.UpdateModelRequest):
+            request = model_service.UpdateModelRequest(request)
+
+        # If we have keyword arguments corresponding to fields on the
+        # request, apply these.
+
+        if model is not None:
+            request.model = model
+        if update_mask is not None:
+            request.update_mask = update_mask
+
+        # Wrap the RPC method; this adds retry and timeout information,
+        # and friendly error handling.
+        rpc = self._transport._wrapped_methods[self._transport.update_model]
+
+        # Certain fields should be provided within the metadata header;
+        # add these here.
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata(
+                (("model.name", request.model.name),)
+            ),
+        )
+
+        # Send the request.
+        response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,)
+
+        # Done; return the response.
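+        # A minimal update sketch (names are placeholders): to change only the
+        # display name, mask exactly that field so no other fields are touched:
+        #
+        #     from google.protobuf import field_mask_pb2
+        #
+        #     client.update_model(
+        #         model=my_model,
+        #         update_mask=field_mask_pb2.FieldMask(paths=["display_name"]),
+        #     )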
+ return response + + def delete_model( + self, + request: model_service.DeleteModelRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Deletes a Model. + Note: Model can only be deleted if there are no + DeployedModels created from it. + + Args: + request (google.cloud.aiplatform_v1.types.DeleteModelRequest): + The request object. Request message for + ``ModelService.DeleteModel``. + name (str): + Required. The name of the Model resource to be deleted. + Format: + ``projects/{project}/locations/{location}/models/{model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.DeleteModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.DeleteModelRequest): + request = model_service.DeleteModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. 
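+        # The future returned here completes only once the server-side deletion
+        # finishes. A minimal blocking sketch (the timeout value is arbitrary):
+        #
+        #     operation = client.delete_model(name=model_name)
+        #     operation.result(timeout=300)  # resolves to empty_pb2.Empty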
+ return response + + def export_model( + self, + request: model_service.ExportModelRequest = None, + *, + name: str = None, + output_config: model_service.ExportModelRequest.OutputConfig = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Exports a trained, exportable, Model to a location specified by + the user. A Model is considered to be exportable if it has at + least one [supported export + format][google.cloud.aiplatform.v1.Model.supported_export_formats]. + + Args: + request (google.cloud.aiplatform_v1.types.ExportModelRequest): + The request object. Request message for + ``ModelService.ExportModel``. + name (str): + Required. The resource name of the Model to export. + Format: + ``projects/{project}/locations/{location}/models/{model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + output_config (google.cloud.aiplatform_v1.types.ExportModelRequest.OutputConfig): + Required. The desired output location + and configuration. + + This corresponds to the ``output_config`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.cloud.aiplatform_v1.types.ExportModelResponse` + Response message of + ``ModelService.ExportModel`` + operation. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name, output_config]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ExportModelRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ExportModelRequest): + request = model_service.ExportModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + if output_config is not None: + request.output_config = output_config + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.export_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + model_service.ExportModelResponse, + metadata_type=model_service.ExportModelOperationMetadata, + ) + + # Done; return the response. 
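+        # While the export runs, ``response.metadata`` surfaces the
+        # ``ExportModelOperationMetadata`` declared above (including output
+        # info), and ``response.result()`` blocks until the terminal
+        # ``ExportModelResponse`` arrives.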
+ return response + + def get_model_evaluation( + self, + request: model_service.GetModelEvaluationRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation.ModelEvaluation: + r"""Gets a ModelEvaluation. + + Args: + request (google.cloud.aiplatform_v1.types.GetModelEvaluationRequest): + The request object. Request message for + ``ModelService.GetModelEvaluation``. + name (str): + Required. The name of the ModelEvaluation resource. + Format: + + ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.ModelEvaluation: + A collection of metrics calculated by + comparing Model's predictions on all of + the test data against annotations from + the test data. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.GetModelEvaluationRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.GetModelEvaluationRequest): + request = model_service.GetModelEvaluationRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_model_evaluation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_model_evaluations( + self, + request: model_service.ListModelEvaluationsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationsPager: + r"""Lists ModelEvaluations in a Model. + + Args: + request (google.cloud.aiplatform_v1.types.ListModelEvaluationsRequest): + The request object. Request message for + ``ModelService.ListModelEvaluations``. + parent (str): + Required. The resource name of the Model to list the + ModelEvaluations from. Format: + ``projects/{project}/locations/{location}/models/{model}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. 
+ timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.model_service.pagers.ListModelEvaluationsPager: + Response message for + ``ModelService.ListModelEvaluations``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ListModelEvaluationsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ListModelEvaluationsRequest): + request = model_service.ListModelEvaluationsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_model_evaluations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListModelEvaluationsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def get_model_evaluation_slice( + self, + request: model_service.GetModelEvaluationSliceRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> model_evaluation_slice.ModelEvaluationSlice: + r"""Gets a ModelEvaluationSlice. + + Args: + request (google.cloud.aiplatform_v1.types.GetModelEvaluationSliceRequest): + The request object. Request message for + ``ModelService.GetModelEvaluationSlice``. + name (str): + Required. The name of the ModelEvaluationSlice resource. + Format: + + ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.ModelEvaluationSlice: + A collection of metrics calculated by + comparing Model's predictions on a slice + of the test data against ground truth + annotations. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. 
+ has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.GetModelEvaluationSliceRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.GetModelEvaluationSliceRequest): + request = model_service.GetModelEvaluationSliceRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.get_model_evaluation_slice + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_model_evaluation_slices( + self, + request: model_service.ListModelEvaluationSlicesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListModelEvaluationSlicesPager: + r"""Lists ModelEvaluationSlices in a ModelEvaluation. + + Args: + request (google.cloud.aiplatform_v1.types.ListModelEvaluationSlicesRequest): + The request object. Request message for + ``ModelService.ListModelEvaluationSlices``. + parent (str): + Required. The resource name of the ModelEvaluation to + list the ModelEvaluationSlices from. Format: + + ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.model_service.pagers.ListModelEvaluationSlicesPager: + Response message for + ``ModelService.ListModelEvaluationSlices``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a model_service.ListModelEvaluationSlicesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, model_service.ListModelEvaluationSlicesRequest): + request = model_service.ListModelEvaluationSlicesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
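+        # The pager returned below fetches follow-up pages lazily as iteration
+        # crosses page boundaries. A minimal sketch (``evaluation_name`` is a
+        # placeholder ModelEvaluation resource name):
+        #
+        #     for s in client.list_model_evaluation_slices(parent=evaluation_name):
+        #         print(s.name)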
+ + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[ + self._transport.list_model_evaluation_slices + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListModelEvaluationSlicesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("ModelServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/model_service/pagers.py b/google/cloud/aiplatform_v1/services/model_service/pagers.py new file mode 100644 index 0000000000..be652f745f --- /dev/null +++ b/google/cloud/aiplatform_v1/services/model_service/pagers.py @@ -0,0 +1,411 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple + +from google.cloud.aiplatform_v1.types import model +from google.cloud.aiplatform_v1.types import model_evaluation +from google.cloud.aiplatform_v1.types import model_evaluation_slice +from google.cloud.aiplatform_v1.types import model_service + + +class ListModelsPager: + """A pager for iterating through ``list_models`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListModelsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``models`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListModels`` requests and continue to iterate + through the ``models`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., model_service.ListModelsResponse], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListModelsRequest): + The initial request object. 
+ response (google.cloud.aiplatform_v1.types.ListModelsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = model_service.ListModelsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[model_service.ListModelsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[model.Model]: + for page in self.pages: + yield from page.models + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListModelsAsyncPager: + """A pager for iterating through ``list_models`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListModelsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``models`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListModels`` requests and continue to iterate + through the ``models`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListModelsResponse]], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListModelsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListModelsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = model_service.ListModelsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[model_service.ListModelsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[model.Model]: + async def async_generator(): + async for page in self.pages: + for response in page.models: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListModelEvaluationsPager: + """A pager for iterating through ``list_model_evaluations`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListModelEvaluationsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``model_evaluations`` field. 
+ + If there are more pages, the ``__iter__`` method will make additional + ``ListModelEvaluations`` requests and continue to iterate + through the ``model_evaluations`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListModelEvaluationsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., model_service.ListModelEvaluationsResponse], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListModelEvaluationsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListModelEvaluationsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = model_service.ListModelEvaluationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[model_service.ListModelEvaluationsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[model_evaluation.ModelEvaluation]: + for page in self.pages: + yield from page.model_evaluations + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListModelEvaluationsAsyncPager: + """A pager for iterating through ``list_model_evaluations`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListModelEvaluationsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``model_evaluations`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListModelEvaluations`` requests and continue to iterate + through the ``model_evaluations`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListModelEvaluationsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListModelEvaluationsResponse]], + request: model_service.ListModelEvaluationsRequest, + response: model_service.ListModelEvaluationsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListModelEvaluationsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListModelEvaluationsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = model_service.ListModelEvaluationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterable[model_service.ListModelEvaluationsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[model_evaluation.ModelEvaluation]: + async def async_generator(): + async for page in self.pages: + for response in page.model_evaluations: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListModelEvaluationSlicesPager: + """A pager for iterating through ``list_model_evaluation_slices`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListModelEvaluationSlicesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``model_evaluation_slices`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListModelEvaluationSlices`` requests and continue to iterate + through the ``model_evaluation_slices`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListModelEvaluationSlicesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., model_service.ListModelEvaluationSlicesResponse], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListModelEvaluationSlicesRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListModelEvaluationSlicesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = model_service.ListModelEvaluationSlicesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[model_service.ListModelEvaluationSlicesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[model_evaluation_slice.ModelEvaluationSlice]: + for page in self.pages: + yield from page.model_evaluation_slices + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListModelEvaluationSlicesAsyncPager: + """A pager for iterating through ``list_model_evaluation_slices`` requests. 
+ + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListModelEvaluationSlicesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``model_evaluation_slices`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListModelEvaluationSlices`` requests and continue to iterate + through the ``model_evaluation_slices`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListModelEvaluationSlicesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[model_service.ListModelEvaluationSlicesResponse] + ], + request: model_service.ListModelEvaluationSlicesRequest, + response: model_service.ListModelEvaluationSlicesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListModelEvaluationSlicesRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListModelEvaluationSlicesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = model_service.ListModelEvaluationSlicesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[model_service.ListModelEvaluationSlicesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[model_evaluation_slice.ModelEvaluationSlice]: + async def async_generator(): + async for page in self.pages: + for response in page.model_evaluation_slices: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/model_service/transports/__init__.py new file mode 100644 index 0000000000..5d1cb51abc --- /dev/null +++ b/google/cloud/aiplatform_v1/services/model_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import ModelServiceTransport +from .grpc import ModelServiceGrpcTransport +from .grpc_asyncio import ModelServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] +_transport_registry["grpc"] = ModelServiceGrpcTransport +_transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport + +__all__ = ( + "ModelServiceTransport", + "ModelServiceGrpcTransport", + "ModelServiceGrpcAsyncIOTransport", +) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1/services/model_service/transports/base.py new file mode 100644 index 0000000000..d937f09a61 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/model_service/transports/base.py @@ -0,0 +1,266 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +import typing +import pkg_resources + +from google import auth # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.api_core import operations_v1 # type: ignore +from google.auth import credentials # type: ignore + +from google.cloud.aiplatform_v1.types import model +from google.cloud.aiplatform_v1.types import model as gca_model +from google.cloud.aiplatform_v1.types import model_evaluation +from google.cloud.aiplatform_v1.types import model_evaluation_slice +from google.cloud.aiplatform_v1.types import model_service +from google.longrunning import operations_pb2 as operations # type: ignore + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +class ModelServiceTransport(abc.ABC): + """Abstract transport class for ModelService.""" + + AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: typing.Optional[str] = None, + scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES, + quota_project_id: typing.Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. 
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+        """
+        # Save the hostname. Default to port 443 (HTTPS) if none is specified.
+        if ":" not in host:
+            host += ":443"
+        self._host = host
+
+        # If no credentials are provided, then determine the appropriate
+        # defaults.
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs(
+                "'credentials_file' and 'credentials' are mutually exclusive"
+            )
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file, scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(
+                scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        # Save the credentials.
+        self._credentials = credentials
+
+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
+        self._wrapped_methods = {
+            self.upload_model: gapic_v1.method.wrap_method(
+                self.upload_model, default_timeout=None, client_info=client_info,
+            ),
+            self.get_model: gapic_v1.method.wrap_method(
+                self.get_model, default_timeout=None, client_info=client_info,
+            ),
+            self.list_models: gapic_v1.method.wrap_method(
+                self.list_models, default_timeout=None, client_info=client_info,
+            ),
+            self.update_model: gapic_v1.method.wrap_method(
+                self.update_model, default_timeout=None, client_info=client_info,
+            ),
+            self.delete_model: gapic_v1.method.wrap_method(
+                self.delete_model, default_timeout=None, client_info=client_info,
+            ),
+            self.export_model: gapic_v1.method.wrap_method(
+                self.export_model, default_timeout=None, client_info=client_info,
+            ),
+            self.get_model_evaluation: gapic_v1.method.wrap_method(
+                self.get_model_evaluation,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.list_model_evaluations: gapic_v1.method.wrap_method(
+                self.list_model_evaluations,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.get_model_evaluation_slice: gapic_v1.method.wrap_method(
+                self.get_model_evaluation_slice,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+            self.list_model_evaluation_slices: gapic_v1.method.wrap_method(
+                self.list_model_evaluation_slices,
+                default_timeout=None,
+                client_info=client_info,
+            ),
+        }
+
+    @property
+    def operations_client(self) -> operations_v1.OperationsClient:
+        """Return the client designed to process long-running operations."""
+        raise NotImplementedError()
+
+    @property
+    def upload_model(
+        self,
+    ) -> typing.Callable[
+        [model_service.UploadModelRequest],
+        typing.Union[operations.Operation, typing.Awaitable[operations.Operation]],
+    ]:
+        raise NotImplementedError()
+
+    @property
+    def get_model(
+        self,
+    ) -> typing.Callable[
+        [model_service.GetModelRequest],
+        typing.Union[model.Model, typing.Awaitable[model.Model]],
+    ]:
+        raise NotImplementedError()
+
+    @property
+    def list_models(
+        self,
+    ) -> typing.Callable[
[model_service.ListModelsRequest], + typing.Union[ + model_service.ListModelsResponse, + typing.Awaitable[model_service.ListModelsResponse], + ], + ]: + raise NotImplementedError() + + @property + def update_model( + self, + ) -> typing.Callable[ + [model_service.UpdateModelRequest], + typing.Union[gca_model.Model, typing.Awaitable[gca_model.Model]], + ]: + raise NotImplementedError() + + @property + def delete_model( + self, + ) -> typing.Callable[ + [model_service.DeleteModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def export_model( + self, + ) -> typing.Callable[ + [model_service.ExportModelRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def get_model_evaluation( + self, + ) -> typing.Callable[ + [model_service.GetModelEvaluationRequest], + typing.Union[ + model_evaluation.ModelEvaluation, + typing.Awaitable[model_evaluation.ModelEvaluation], + ], + ]: + raise NotImplementedError() + + @property + def list_model_evaluations( + self, + ) -> typing.Callable[ + [model_service.ListModelEvaluationsRequest], + typing.Union[ + model_service.ListModelEvaluationsResponse, + typing.Awaitable[model_service.ListModelEvaluationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_model_evaluation_slice( + self, + ) -> typing.Callable[ + [model_service.GetModelEvaluationSliceRequest], + typing.Union[ + model_evaluation_slice.ModelEvaluationSlice, + typing.Awaitable[model_evaluation_slice.ModelEvaluationSlice], + ], + ]: + raise NotImplementedError() + + @property + def list_model_evaluation_slices( + self, + ) -> typing.Callable[ + [model_service.ListModelEvaluationSlicesRequest], + typing.Union[ + model_service.ListModelEvaluationSlicesResponse, + typing.Awaitable[model_service.ListModelEvaluationSlicesResponse], + ], + ]: + raise NotImplementedError() + + +__all__ = ("ModelServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py new file mode 100644 index 0000000000..90dcfd008d --- /dev/null +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc.py @@ -0,0 +1,547 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import warnings
+from typing import Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import grpc_helpers  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+
+from google.cloud.aiplatform_v1.types import model
+from google.cloud.aiplatform_v1.types import model as gca_model
+from google.cloud.aiplatform_v1.types import model_evaluation
+from google.cloud.aiplatform_v1.types import model_evaluation_slice
+from google.cloud.aiplatform_v1.types import model_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO
+
+
+class ModelServiceGrpcTransport(ModelServiceTransport):
+    """gRPC backend transport for ModelService.
+
+    A service for managing AI Platform's machine learning Models.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _stubs: Dict[str, Callable]
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Sequence[str] = None,
+        channel: grpc.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+            channel (Optional[grpc.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None + + # Run the base constructor. 
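+        # As a sketch of the control flow: the base constructor resolves
+        # credentials (explicit argument, credentials file, or Application
+        # Default Credentials) and pre-wraps every RPC with its retry/timeout
+        # defaults via ``_prep_wrapped_messages``.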
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> grpc.Channel:
+        """Create and return a gRPC channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    @property
+    def grpc_channel(self) -> grpc.Channel:
+        """Return the channel designed to connect to this service.
+        """
+        return self._grpc_channel
+
+    @property
+    def operations_client(self) -> operations_v1.OperationsClient:
+        """Create the client designed to process long-running operations.
+
+        This property caches on the instance; repeated calls return the same
+        client.
+        """
+        # Sanity check: Only create a new client if we do not already have one.
+        if self._operations_client is None:
+            self._operations_client = operations_v1.OperationsClient(self.grpc_channel)
+
+        # Return the client from cache.
+        return self._operations_client
+
+    @property
+    def upload_model(
+        self,
+    ) -> Callable[[model_service.UploadModelRequest], operations.Operation]:
+        r"""Return a callable for the upload model method over gRPC.
+
+        Uploads a Model artifact into AI Platform.
+
+        Returns:
+            Callable[[~.UploadModelRequest],
+                    ~.Operation]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "upload_model" not in self._stubs:
+            self._stubs["upload_model"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.ModelService/UploadModel",
+                request_serializer=model_service.UploadModelRequest.serialize,
+                response_deserializer=operations.Operation.FromString,
+            )
+        return self._stubs["upload_model"]
+
+    @property
+    def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]:
+        r"""Return a callable for the get model method over gRPC.
+
+        Gets a Model.
+ + Returns: + Callable[[~.GetModelRequest], + ~.Model]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_model" not in self._stubs: + self._stubs["get_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/GetModel", + request_serializer=model_service.GetModelRequest.serialize, + response_deserializer=model.Model.deserialize, + ) + return self._stubs["get_model"] + + @property + def list_models( + self, + ) -> Callable[[model_service.ListModelsRequest], model_service.ListModelsResponse]: + r"""Return a callable for the list models method over gRPC. + + Lists Models in a Location. + + Returns: + Callable[[~.ListModelsRequest], + ~.ListModelsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_models" not in self._stubs: + self._stubs["list_models"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ListModels", + request_serializer=model_service.ListModelsRequest.serialize, + response_deserializer=model_service.ListModelsResponse.deserialize, + ) + return self._stubs["list_models"] + + @property + def update_model( + self, + ) -> Callable[[model_service.UpdateModelRequest], gca_model.Model]: + r"""Return a callable for the update model method over gRPC. + + Updates a Model. + + Returns: + Callable[[~.UpdateModelRequest], + ~.Model]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_model" not in self._stubs: + self._stubs["update_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/UpdateModel", + request_serializer=model_service.UpdateModelRequest.serialize, + response_deserializer=gca_model.Model.deserialize, + ) + return self._stubs["update_model"] + + @property + def delete_model( + self, + ) -> Callable[[model_service.DeleteModelRequest], operations.Operation]: + r"""Return a callable for the delete model method over gRPC. + + Deletes a Model. + Note: Model can only be deleted if there are no + DeployedModels created from it. + + Returns: + Callable[[~.DeleteModelRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_model" not in self._stubs: + self._stubs["delete_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/DeleteModel", + request_serializer=model_service.DeleteModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_model"] + + @property + def export_model( + self, + ) -> Callable[[model_service.ExportModelRequest], operations.Operation]: + r"""Return a callable for the export model method over gRPC. 
+ + Exports a trained, exportable, Model to a location specified by + the user. A Model is considered to be exportable if it has at + least one [supported export + format][google.cloud.aiplatform.v1.Model.supported_export_formats]. + + Returns: + Callable[[~.ExportModelRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "export_model" not in self._stubs: + self._stubs["export_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ExportModel", + request_serializer=model_service.ExportModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["export_model"] + + @property + def get_model_evaluation( + self, + ) -> Callable[ + [model_service.GetModelEvaluationRequest], model_evaluation.ModelEvaluation + ]: + r"""Return a callable for the get model evaluation method over gRPC. + + Gets a ModelEvaluation. + + Returns: + Callable[[~.GetModelEvaluationRequest], + ~.ModelEvaluation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_model_evaluation" not in self._stubs: + self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/GetModelEvaluation", + request_serializer=model_service.GetModelEvaluationRequest.serialize, + response_deserializer=model_evaluation.ModelEvaluation.deserialize, + ) + return self._stubs["get_model_evaluation"] + + @property + def list_model_evaluations( + self, + ) -> Callable[ + [model_service.ListModelEvaluationsRequest], + model_service.ListModelEvaluationsResponse, + ]: + r"""Return a callable for the list model evaluations method over gRPC. + + Lists ModelEvaluations in a Model. + + Returns: + Callable[[~.ListModelEvaluationsRequest], + ~.ListModelEvaluationsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_model_evaluations" not in self._stubs: + self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ListModelEvaluations", + request_serializer=model_service.ListModelEvaluationsRequest.serialize, + response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, + ) + return self._stubs["list_model_evaluations"] + + @property + def get_model_evaluation_slice( + self, + ) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + model_evaluation_slice.ModelEvaluationSlice, + ]: + r"""Return a callable for the get model evaluation slice method over gRPC. + + Gets a ModelEvaluationSlice. + + Returns: + Callable[[~.GetModelEvaluationSliceRequest], + ~.ModelEvaluationSlice]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. 
+ # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_model_evaluation_slice" not in self._stubs: + self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/GetModelEvaluationSlice", + request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, + response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, + ) + return self._stubs["get_model_evaluation_slice"] + + @property + def list_model_evaluation_slices( + self, + ) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + model_service.ListModelEvaluationSlicesResponse, + ]: + r"""Return a callable for the list model evaluation slices method over gRPC. + + Lists ModelEvaluationSlices in a ModelEvaluation. + + Returns: + Callable[[~.ListModelEvaluationSlicesRequest], + ~.ListModelEvaluationSlicesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_model_evaluation_slices" not in self._stubs: + self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ListModelEvaluationSlices", + request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, + response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, + ) + return self._stubs["list_model_evaluation_slices"] + + +__all__ = ("ModelServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..2aeffea93f --- /dev/null +++ b/google/cloud/aiplatform_v1/services/model_service/transports/grpc_asyncio.py @@ -0,0 +1,558 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
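Each RPC property on the sync transport above lazily builds a bare unary-unary stub and caches it in ``_stubs``; the returned callable accepts a proto-plus request and handles (de)serialization itself. A minimal sketch of driving the transport directly, assuming application default credentials are configured and using a placeholder model name (in normal use the higher-level ModelServiceClient wraps this transport):

    from google.cloud.aiplatform_v1.services.model_service.transports.grpc import (
        ModelServiceGrpcTransport,
    )
    from google.cloud.aiplatform_v1.types import model_service

    # No channel or credentials supplied: the transport falls through to the
    # final branch of __init__ and builds a channel from default credentials.
    transport = ModelServiceGrpcTransport()

    # Accessing the property creates (and caches) the stub; calling it
    # performs one GetModel RPC.
    request = model_service.GetModelRequest(
        name="projects/my-project/locations/us-central1/models/123"  # placeholder
    )
    fetched = transport.get_model(request)
    print(fetched.display_name)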
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1.types import model
+from google.cloud.aiplatform_v1.types import model as gca_model
+from google.cloud.aiplatform_v1.types import model_evaluation
+from google.cloud.aiplatform_v1.types import model_evaluation_slice
+from google.cloud.aiplatform_v1.types import model_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+from .base import ModelServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import ModelServiceGrpcTransport
+
+
+class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport):
+    """gRPC AsyncIO backend transport for ModelService.
+
+    A service for managing AI Platform's machine learning Models.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+ """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
+ """ + self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + self._operations_client = None + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. 
+ return self._operations_client + + @property + def upload_model( + self, + ) -> Callable[[model_service.UploadModelRequest], Awaitable[operations.Operation]]: + r"""Return a callable for the upload model method over gRPC. + + Uploads a Model artifact into AI Platform. + + Returns: + Callable[[~.UploadModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "upload_model" not in self._stubs: + self._stubs["upload_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/UploadModel", + request_serializer=model_service.UploadModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["upload_model"] + + @property + def get_model( + self, + ) -> Callable[[model_service.GetModelRequest], Awaitable[model.Model]]: + r"""Return a callable for the get model method over gRPC. + + Gets a Model. + + Returns: + Callable[[~.GetModelRequest], + Awaitable[~.Model]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_model" not in self._stubs: + self._stubs["get_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/GetModel", + request_serializer=model_service.GetModelRequest.serialize, + response_deserializer=model.Model.deserialize, + ) + return self._stubs["get_model"] + + @property + def list_models( + self, + ) -> Callable[ + [model_service.ListModelsRequest], Awaitable[model_service.ListModelsResponse] + ]: + r"""Return a callable for the list models method over gRPC. + + Lists Models in a Location. + + Returns: + Callable[[~.ListModelsRequest], + Awaitable[~.ListModelsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_models" not in self._stubs: + self._stubs["list_models"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ListModels", + request_serializer=model_service.ListModelsRequest.serialize, + response_deserializer=model_service.ListModelsResponse.deserialize, + ) + return self._stubs["list_models"] + + @property + def update_model( + self, + ) -> Callable[[model_service.UpdateModelRequest], Awaitable[gca_model.Model]]: + r"""Return a callable for the update model method over gRPC. + + Updates a Model. + + Returns: + Callable[[~.UpdateModelRequest], + Awaitable[~.Model]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "update_model" not in self._stubs: + self._stubs["update_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/UpdateModel", + request_serializer=model_service.UpdateModelRequest.serialize, + response_deserializer=gca_model.Model.deserialize, + ) + return self._stubs["update_model"] + + @property + def delete_model( + self, + ) -> Callable[[model_service.DeleteModelRequest], Awaitable[operations.Operation]]: + r"""Return a callable for the delete model method over gRPC. + + Deletes a Model. + Note: Model can only be deleted if there are no + DeployedModels created from it. + + Returns: + Callable[[~.DeleteModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_model" not in self._stubs: + self._stubs["delete_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/DeleteModel", + request_serializer=model_service.DeleteModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_model"] + + @property + def export_model( + self, + ) -> Callable[[model_service.ExportModelRequest], Awaitable[operations.Operation]]: + r"""Return a callable for the export model method over gRPC. + + Exports a trained, exportable, Model to a location specified by + the user. A Model is considered to be exportable if it has at + least one [supported export + format][google.cloud.aiplatform.v1.Model.supported_export_formats]. + + Returns: + Callable[[~.ExportModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "export_model" not in self._stubs: + self._stubs["export_model"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ExportModel", + request_serializer=model_service.ExportModelRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["export_model"] + + @property + def get_model_evaluation( + self, + ) -> Callable[ + [model_service.GetModelEvaluationRequest], + Awaitable[model_evaluation.ModelEvaluation], + ]: + r"""Return a callable for the get model evaluation method over gRPC. + + Gets a ModelEvaluation. + + Returns: + Callable[[~.GetModelEvaluationRequest], + Awaitable[~.ModelEvaluation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "get_model_evaluation" not in self._stubs: + self._stubs["get_model_evaluation"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/GetModelEvaluation", + request_serializer=model_service.GetModelEvaluationRequest.serialize, + response_deserializer=model_evaluation.ModelEvaluation.deserialize, + ) + return self._stubs["get_model_evaluation"] + + @property + def list_model_evaluations( + self, + ) -> Callable[ + [model_service.ListModelEvaluationsRequest], + Awaitable[model_service.ListModelEvaluationsResponse], + ]: + r"""Return a callable for the list model evaluations method over gRPC. + + Lists ModelEvaluations in a Model. + + Returns: + Callable[[~.ListModelEvaluationsRequest], + Awaitable[~.ListModelEvaluationsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_model_evaluations" not in self._stubs: + self._stubs["list_model_evaluations"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ListModelEvaluations", + request_serializer=model_service.ListModelEvaluationsRequest.serialize, + response_deserializer=model_service.ListModelEvaluationsResponse.deserialize, + ) + return self._stubs["list_model_evaluations"] + + @property + def get_model_evaluation_slice( + self, + ) -> Callable[ + [model_service.GetModelEvaluationSliceRequest], + Awaitable[model_evaluation_slice.ModelEvaluationSlice], + ]: + r"""Return a callable for the get model evaluation slice method over gRPC. + + Gets a ModelEvaluationSlice. + + Returns: + Callable[[~.GetModelEvaluationSliceRequest], + Awaitable[~.ModelEvaluationSlice]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_model_evaluation_slice" not in self._stubs: + self._stubs["get_model_evaluation_slice"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/GetModelEvaluationSlice", + request_serializer=model_service.GetModelEvaluationSliceRequest.serialize, + response_deserializer=model_evaluation_slice.ModelEvaluationSlice.deserialize, + ) + return self._stubs["get_model_evaluation_slice"] + + @property + def list_model_evaluation_slices( + self, + ) -> Callable[ + [model_service.ListModelEvaluationSlicesRequest], + Awaitable[model_service.ListModelEvaluationSlicesResponse], + ]: + r"""Return a callable for the list model evaluation slices method over gRPC. + + Lists ModelEvaluationSlices in a ModelEvaluation. + + Returns: + Callable[[~.ListModelEvaluationSlicesRequest], + Awaitable[~.ListModelEvaluationSlicesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "list_model_evaluation_slices" not in self._stubs: + self._stubs["list_model_evaluation_slices"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.ModelService/ListModelEvaluationSlices", + request_serializer=model_service.ListModelEvaluationSlicesRequest.serialize, + response_deserializer=model_service.ListModelEvaluationSlicesResponse.deserialize, + ) + return self._stubs["list_model_evaluation_slices"] + + +__all__ = ("ModelServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/__init__.py b/google/cloud/aiplatform_v1/services/pipeline_service/__init__.py new file mode 100644 index 0000000000..7f02b47358 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/pipeline_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .client import PipelineServiceClient +from .async_client import PipelineServiceAsyncClient + +__all__ = ( + "PipelineServiceClient", + "PipelineServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py new file mode 100644 index 0000000000..95c7d8a176 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/pipeline_service/async_client.py @@ -0,0 +1,600 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.pipeline_service import pagers +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import model +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import pipeline_service +from google.cloud.aiplatform_v1.types import pipeline_state +from google.cloud.aiplatform_v1.types import training_pipeline +from google.cloud.aiplatform_v1.types import training_pipeline as gca_training_pipeline +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + +from .transports.base import PipelineServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import PipelineServiceGrpcAsyncIOTransport +from .client import PipelineServiceClient + + +class PipelineServiceAsyncClient: + """A service for creating and managing AI Platform's pipelines.""" + + _client: PipelineServiceClient + + DEFAULT_ENDPOINT = PipelineServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = PipelineServiceClient.DEFAULT_MTLS_ENDPOINT + + endpoint_path = staticmethod(PipelineServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(PipelineServiceClient.parse_endpoint_path) + model_path = staticmethod(PipelineServiceClient.model_path) + parse_model_path = staticmethod(PipelineServiceClient.parse_model_path) + training_pipeline_path = staticmethod(PipelineServiceClient.training_pipeline_path) + parse_training_pipeline_path = staticmethod( + PipelineServiceClient.parse_training_pipeline_path + ) + + common_billing_account_path = staticmethod( + PipelineServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PipelineServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(PipelineServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + PipelineServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + PipelineServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PipelineServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(PipelineServiceClient.common_project_path) + parse_common_project_path = staticmethod( + PipelineServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod(PipelineServiceClient.common_location_path) + parse_common_location_path = staticmethod( + PipelineServiceClient.parse_common_location_path + ) + + from_service_account_info = PipelineServiceClient.from_service_account_info + from_service_account_file = PipelineServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + 
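All of the resource-path helpers re-exported above are plain static methods, so they can be used without instantiating a client and round-trip through their parse_* counterparts. A small sketch with placeholder IDs:

    from google.cloud.aiplatform_v1.services.pipeline_service.async_client import (
        PipelineServiceAsyncClient,
    )

    # Compose a fully-qualified TrainingPipeline resource name...
    name = PipelineServiceAsyncClient.training_pipeline_path(
        "my-project", "us-central1", "12345"  # placeholder project/location/ID
    )
    assert name == (
        "projects/my-project/locations/us-central1/trainingPipelines/12345"
    )

    # ...and decompose it again into its segments.
    assert PipelineServiceAsyncClient.parse_training_pipeline_path(name) == {
        "project": "my-project",
        "location": "us-central1",
        "training_pipeline": "12345",
    }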
+    @property
+    def transport(self) -> PipelineServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            PipelineServiceTransport: The transport used by the client instance.
+        """
+        return self._client.transport
+
+    get_transport_class = functools.partial(
+        type(PipelineServiceClient).get_transport_class, type(PipelineServiceClient)
+    )
+
+    def __init__(
+        self,
+        *,
+        credentials: credentials.Credentials = None,
+        transport: Union[str, PipelineServiceTransport] = "grpc_asyncio",
+        client_options: ClientOptions = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the pipeline service client.
+
+        Args:
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            transport (Union[str, ~.PipelineServiceTransport]): The
+                transport to use. If set to None, a transport is chosen
+                automatically.
+            client_options (ClientOptions): Custom options for the client. It
+                won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+        """
+
+        self._client = PipelineServiceClient(
+            credentials=credentials,
+            transport=transport,
+            client_options=client_options,
+            client_info=client_info,
+        )
+
+    async def create_training_pipeline(
+        self,
+        request: pipeline_service.CreateTrainingPipelineRequest = None,
+        *,
+        parent: str = None,
+        training_pipeline: gca_training_pipeline.TrainingPipeline = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> gca_training_pipeline.TrainingPipeline:
+        r"""Creates a TrainingPipeline. A created
+        TrainingPipeline will be attempted to run right away.
+
+        Args:
+            request (:class:`google.cloud.aiplatform_v1.types.CreateTrainingPipelineRequest`):
+                The request object. Request message for
+                ``PipelineService.CreateTrainingPipeline``.
+            parent (:class:`str`):
+                Required. The resource name of the Location to create
+                the TrainingPipeline in. Format:
+                ``projects/{project}/locations/{location}``
+
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            training_pipeline (:class:`google.cloud.aiplatform_v1.types.TrainingPipeline`):
+                Required. The TrainingPipeline to
+                create.
+ + This corresponds to the ``training_pipeline`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.TrainingPipeline: + The TrainingPipeline orchestrates tasks associated with training a Model. It + always executes the training task, and optionally may + also export data from AI Platform's Dataset which + becomes the training input, + ``upload`` + the Model to AI Platform, and evaluate the Model. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, training_pipeline]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.CreateTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if training_pipeline is not None: + request.training_pipeline = training_pipeline + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_training_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def get_training_pipeline( + self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: + r"""Gets a TrainingPipeline. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.GetTrainingPipelineRequest`): + The request object. Request message for + ``PipelineService.GetTrainingPipeline``. + name (:class:`str`): + Required. The name of the TrainingPipeline resource. + Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.TrainingPipeline: + The TrainingPipeline orchestrates tasks associated with training a Model. It + always executes the training task, and optionally may + also export data from AI Platform's Dataset which + becomes the training input, + ``upload`` + the Model to AI Platform, and evaluate the Model. + + """ + # Create or coerce a protobuf request object. 
+ # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.GetTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_training_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_training_pipelines( + self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesAsyncPager: + r"""Lists TrainingPipelines in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListTrainingPipelinesRequest`): + The request object. Request message for + ``PipelineService.ListTrainingPipelines``. + parent (:class:`str`): + Required. The resource name of the Location to list the + TrainingPipelines from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.pipeline_service.pagers.ListTrainingPipelinesAsyncPager: + Response message for + ``PipelineService.ListTrainingPipelines`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.ListTrainingPipelinesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_training_pipelines, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. 
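As the docstring above notes, the returned async pager resolves additional pages transparently, so callers iterate with ``async for`` and never handle page tokens themselves. A minimal sketch, assuming application default credentials and a placeholder parent:

    import asyncio

    from google.cloud.aiplatform_v1.services.pipeline_service.async_client import (
        PipelineServiceAsyncClient,
    )

    async def main():
        client = PipelineServiceAsyncClient()
        parent = "projects/my-project/locations/us-central1"  # placeholder
        # Awaiting the call yields the pager; iterating it fetches further
        # pages on demand.
        async for pipeline in await client.list_training_pipelines(parent=parent):
            print(pipeline.display_name, pipeline.state)

    asyncio.run(main())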
+ metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListTrainingPipelinesAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_training_pipeline( + self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a TrainingPipeline. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.DeleteTrainingPipelineRequest`): + The request object. Request message for + ``PipelineService.DeleteTrainingPipeline``. + name (:class:`str`): + Required. The name of the TrainingPipeline resource to + be deleted. Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.DeleteTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_training_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. 
+ response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def cancel_training_pipeline( + self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on + the TrainingPipeline. The server makes a best effort to cancel + the pipeline, but success is not guaranteed. Clients can use + ``PipelineService.GetTrainingPipeline`` + or other methods to check whether the cancellation succeeded or + whether the pipeline completed despite cancellation. On + successful cancellation, the TrainingPipeline is not deleted; + instead it becomes a pipeline with a + ``TrainingPipeline.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``TrainingPipeline.state`` + is set to ``CANCELLED``. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.CancelTrainingPipelineRequest`): + The request object. Request message for + ``PipelineService.CancelTrainingPipeline``. + name (:class:`str`): + Required. The name of the TrainingPipeline to cancel. + Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = pipeline_service.CancelTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.cancel_training_pipeline, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
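Since cancellation is best-effort and ``cancel_training_pipeline`` returns nothing, the docstring above suggests checking the pipeline afterwards. A minimal polling sketch, assuming application default credentials and a placeholder pipeline name:

    import asyncio

    from google.cloud.aiplatform_v1.services.pipeline_service.async_client import (
        PipelineServiceAsyncClient,
    )
    from google.cloud.aiplatform_v1.types import pipeline_state

    TERMINAL = (
        pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED,
        pipeline_state.PipelineState.PIPELINE_STATE_FAILED,
        pipeline_state.PipelineState.PIPELINE_STATE_CANCELLED,
    )

    async def cancel_and_wait(name: str):
        client = PipelineServiceAsyncClient()
        await client.cancel_training_pipeline(name=name)
        # The pipeline may still succeed or fail despite the cancel request,
        # so poll until any terminal state is reached.
        while True:
            pipeline = await client.get_training_pipeline(name=name)
            if pipeline.state in TERMINAL:
                return pipeline
            await asyncio.sleep(10)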
+ await rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("PipelineServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1/services/pipeline_service/client.py new file mode 100644 index 0000000000..39f37eb72e --- /dev/null +++ b/google/cloud/aiplatform_v1/services/pipeline_service/client.py @@ -0,0 +1,835 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.pipeline_service import pagers +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import model +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import pipeline_service +from google.cloud.aiplatform_v1.types import pipeline_state +from google.cloud.aiplatform_v1.types import training_pipeline +from google.cloud.aiplatform_v1.types import training_pipeline as gca_training_pipeline +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + +from .transports.base import PipelineServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import PipelineServiceGrpcTransport +from .transports.grpc_asyncio import PipelineServiceGrpcAsyncIOTransport + + +class PipelineServiceClientMeta(type): + """Metaclass for the PipelineService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. 
+    """
+
+    _transport_registry = (
+        OrderedDict()
+    )  # type: Dict[str, Type[PipelineServiceTransport]]
+    _transport_registry["grpc"] = PipelineServiceGrpcTransport
+    _transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport
+
+    def get_transport_class(cls, label: str = None,) -> Type[PipelineServiceTransport]:
+        """Return an appropriate transport class.
+
+        Args:
+            label: The name of the desired transport. If none is
+                provided, then the first transport in the registry is used.
+
+        Returns:
+            The transport class to use.
+        """
+        # If a specific transport is requested, return that one.
+        if label:
+            return cls._transport_registry[label]
+
+        # No transport is requested; return the default (that is, the first one
+        # in the dictionary).
+        return next(iter(cls._transport_registry.values()))
+
+
+class PipelineServiceClient(metaclass=PipelineServiceClientMeta):
+    """A service for creating and managing AI Platform's pipelines."""
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Convert api endpoint to mTLS endpoint.
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
+    )
+
+    @classmethod
+    def from_service_account_info(cls, info: dict, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials info.
+
+        Args:
+            info (dict): The service account private key info.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            PipelineServiceClient: The constructed client.
+        """
+        credentials = service_account.Credentials.from_service_account_info(info)
+        kwargs["credentials"] = credentials
+        return cls(*args, **kwargs)
+
+    @classmethod
+    def from_service_account_file(cls, filename: str, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials
+        file.
+
+        Args:
+            filename (str): The path to the service account private key json
+                file.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            PipelineServiceClient: The constructed client.
+        """
+        credentials = service_account.Credentials.from_service_account_file(filename)
+        kwargs["credentials"] = credentials
+        return cls(*args, **kwargs)
+
+    from_service_account_json = from_service_account_file
+
+    @property
+    def transport(self) -> PipelineServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            PipelineServiceTransport: The transport used by the client instance.
+        """
+        return self._transport
+
+    @staticmethod
+    def endpoint_path(project: str, location: str, endpoint: str,) -> str:
+        """Return a fully-qualified endpoint string."""
+        return "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
+            project=project, location=location, endpoint=endpoint,
+        )
+
+    @staticmethod
+    def parse_endpoint_path(path: str) -> Dict[str, str]:
+        """Parse an endpoint path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def model_path(project: str, location: str, model: str,) -> str:
+        """Return a fully-qualified model string."""
+        return "projects/{project}/locations/{location}/models/{model}".format(
+            project=project, location=location, model=model,
+        )
+
+    @staticmethod
+    def parse_model_path(path: str) -> Dict[str, str]:
+        """Parse a model path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/models/(?P<model>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def training_pipeline_path(
+        project: str, location: str, training_pipeline: str,
+    ) -> str:
+        """Return a fully-qualified training_pipeline string."""
+        return "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(
+            project=project, location=location, training_pipeline=training_pipeline,
+        )
+
+    @staticmethod
+    def parse_training_pipeline_path(path: str) -> Dict[str, str]:
+        """Parse a training_pipeline path into its component segments."""
+        m = re.match(
+            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/trainingPipelines/(?P<training_pipeline>.+?)$",
+            path,
+        )
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_billing_account_path(billing_account: str,) -> str:
+        """Return a fully-qualified billing_account string."""
+        return "billingAccounts/{billing_account}".format(
+            billing_account=billing_account,
+        )
+
+    @staticmethod
+    def parse_common_billing_account_path(path: str) -> Dict[str, str]:
+        """Parse a billing_account path into its component segments."""
+        m = re.match(r"^billingAccounts/(?P<billing_account>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_folder_path(folder: str,) -> str:
+        """Return a fully-qualified folder string."""
+        return "folders/{folder}".format(folder=folder,)
+
+    @staticmethod
+    def parse_common_folder_path(path: str) -> Dict[str, str]:
+        """Parse a folder path into its component segments."""
+        m = re.match(r"^folders/(?P<folder>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_organization_path(organization: str,) -> str:
+        """Return a fully-qualified organization string."""
+        return "organizations/{organization}".format(organization=organization,)
+
+    @staticmethod
+    def parse_common_organization_path(path: str) -> Dict[str, str]:
+        """Parse an organization path into its component segments."""
+        m = re.match(r"^organizations/(?P<organization>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_project_path(project: str,) -> str:
+        """Return a fully-qualified project string."""
+        return "projects/{project}".format(project=project,)
+
+    @staticmethod
+    def parse_common_project_path(path: str) -> Dict[str, str]:
+        """Parse a project path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    @staticmethod
+    def common_location_path(project: str, location: str,) -> str:
+        """Return a fully-qualified location string."""
+        return "projects/{project}/locations/{location}".format(
+            project=project, location=location,
+        )
+
+    @staticmethod
+    def parse_common_location_path(path: str) -> Dict[str, str]:
+        """Parse a location path into its component segments."""
+        m = re.match(r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)$", path)
+        return m.groupdict() if m else {}
+
+    def __init__(
+        self,
+        *,
+        credentials: Optional[credentials.Credentials] = None,
+        transport: Union[str, PipelineServiceTransport, None] = None,
+        client_options: Optional[client_options_lib.ClientOptions] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the pipeline service client.
+
+        Args:
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            transport (Union[str, PipelineServiceTransport]): The
+                transport to use. If set to None, a transport is chosen
+                automatically.
+            client_options (google.api_core.client_options.ClientOptions): Custom options for the
+                client. It won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+        """
+        if isinstance(client_options, dict):
+            client_options = client_options_lib.from_dict(client_options)
+        if client_options is None:
+            client_options = client_options_lib.ClientOptions()
+
+        # Create SSL credentials for mutual TLS if needed.
+        use_client_cert = bool(
+            util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false"))
+        )
+
+        client_cert_source_func = None
+        is_mtls = False
+        if use_client_cert:
+            if client_options.client_cert_source:
+                is_mtls = True
+                client_cert_source_func = client_options.client_cert_source
+            else:
+                is_mtls = mtls.has_default_client_cert_source()
+                client_cert_source_func = (
+                    mtls.default_client_cert_source() if is_mtls else None
+                )
+
+        # Figure out which api endpoint to use.
+        if client_options.api_endpoint is not None:
+            api_endpoint = client_options.api_endpoint
+        else:
+            use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto")
+            if use_mtls_env == "never":
+                api_endpoint = self.DEFAULT_ENDPOINT
+            elif use_mtls_env == "always":
+                api_endpoint = self.DEFAULT_MTLS_ENDPOINT
+            elif use_mtls_env == "auto":
+                api_endpoint = (
+                    self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT
+                )
+            else:
+                raise MutualTLSChannelError(
+                    "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always"
+                )
+
+        # Save or instantiate the transport.
+        # Ordinarily, we provide the transport, but allowing a custom transport
+        # instance provides an extensibility point for unusual situations.
+        if isinstance(transport, PipelineServiceTransport):
+            # transport is a PipelineServiceTransport instance.
+            if credentials or client_options.credentials_file:
+                raise ValueError(
+                    "When providing a transport instance, "
+                    "provide its credentials directly."
+                )
+            if client_options.scopes:
+                raise ValueError(
+                    "When providing a transport instance, "
+                    "provide its scopes directly."
+                )
+            self._transport = transport
+        else:
+            Transport = type(self).get_transport_class(transport)
+            self._transport = Transport(
+                credentials=credentials,
+                credentials_file=client_options.credentials_file,
+                host=api_endpoint,
+                scopes=client_options.scopes,
+                client_cert_source_for_mtls=client_cert_source_func,
+                quota_project_id=client_options.quota_project_id,
+                client_info=client_info,
+            )
+
+    def create_training_pipeline(
+        self,
+        request: pipeline_service.CreateTrainingPipelineRequest = None,
+        *,
+        parent: str = None,
+        training_pipeline: gca_training_pipeline.TrainingPipeline = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> gca_training_pipeline.TrainingPipeline:
+        r"""Creates a TrainingPipeline. The server attempts to run a
+        created TrainingPipeline right away.
+
+        Args:
+            request (google.cloud.aiplatform_v1.types.CreateTrainingPipelineRequest):
+                The request object. Request message for
+                ``PipelineService.CreateTrainingPipeline``.
+            parent (str):
+                Required. The resource name of the Location to create
+                the TrainingPipeline in. Format:
+                ``projects/{project}/locations/{location}``
+
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            training_pipeline (google.cloud.aiplatform_v1.types.TrainingPipeline):
+                Required. The TrainingPipeline to
+                create.
+
+                This corresponds to the ``training_pipeline`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+            timeout (float): The timeout for this request.
+            metadata (Sequence[Tuple[str, str]]): Strings which should be
+                sent along with the request as metadata.
+
+        Returns:
+            google.cloud.aiplatform_v1.types.TrainingPipeline:
+                The TrainingPipeline orchestrates tasks associated with training a Model. It
+                always executes the training task, and optionally may
+                also export data from AI Platform's Dataset which
+                becomes the training input,
+                ``upload``
+                the Model to AI Platform, and evaluate the Model.
+
+        """
+        # Create or coerce a protobuf request object.
+        # Sanity check: If we got a request object, we should *not* have
+        # gotten any keyword arguments that map to the request.
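+        # The flattened fields and the `request` object are mutually
+        # exclusive. As a sketch, either call style is accepted:
+        #
+        #   client.create_training_pipeline(
+        #       request={"parent": parent, "training_pipeline": tp})
+        #   client.create_training_pipeline(parent=parent, training_pipeline=tp)
+        #
+        # but combining `request` with `parent`/`training_pipeline` raises
+        # the ValueError below.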
+ has_flattened_params = any([parent, training_pipeline]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.CreateTrainingPipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.CreateTrainingPipelineRequest): + request = pipeline_service.CreateTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if training_pipeline is not None: + request.training_pipeline = training_pipeline + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_training_pipeline] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def get_training_pipeline( + self, + request: pipeline_service.GetTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> training_pipeline.TrainingPipeline: + r"""Gets a TrainingPipeline. + + Args: + request (google.cloud.aiplatform_v1.types.GetTrainingPipelineRequest): + The request object. Request message for + ``PipelineService.GetTrainingPipeline``. + name (str): + Required. The name of the TrainingPipeline resource. + Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.TrainingPipeline: + The TrainingPipeline orchestrates tasks associated with training a Model. It + always executes the training task, and optionally may + also export data from AI Platform's Dataset which + becomes the training input, + ``upload`` + the Model to AI Platform, and evaluate the Model. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.GetTrainingPipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
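+        # Note that `request` may also be given as a plain dict; the
+        # proto-plus constructor below coerces it, e.g. (a sketch):
+        #   pipeline_service.GetTrainingPipelineRequest({"name": name})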
+ if not isinstance(request, pipeline_service.GetTrainingPipelineRequest): + request = pipeline_service.GetTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_training_pipeline] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_training_pipelines( + self, + request: pipeline_service.ListTrainingPipelinesRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListTrainingPipelinesPager: + r"""Lists TrainingPipelines in a Location. + + Args: + request (google.cloud.aiplatform_v1.types.ListTrainingPipelinesRequest): + The request object. Request message for + ``PipelineService.ListTrainingPipelines``. + parent (str): + Required. The resource name of the Location to list the + TrainingPipelines from. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.pipeline_service.pagers.ListTrainingPipelinesPager: + Response message for + ``PipelineService.ListTrainingPipelines`` + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.ListTrainingPipelinesRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.ListTrainingPipelinesRequest): + request = pipeline_service.ListTrainingPipelinesRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_training_pipelines] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. 
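+        # The RPC returns a single ListTrainingPipelinesResponse page; the
+        # pager constructed below resolves follow-up pages lazily, so a
+        # caller can simply iterate (a sketch):
+        #
+        #   for tp in client.list_training_pipelines(parent=parent):
+        #       print(tp.name)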
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListTrainingPipelinesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_training_pipeline( + self, + request: pipeline_service.DeleteTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Deletes a TrainingPipeline. + + Args: + request (google.cloud.aiplatform_v1.types.DeleteTrainingPipelineRequest): + The request object. Request message for + ``PipelineService.DeleteTrainingPipeline``. + name (str): + Required. The name of the TrainingPipeline resource to + be deleted. Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.DeleteTrainingPipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.DeleteTrainingPipelineRequest): + request = pipeline_service.DeleteTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_training_pipeline] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. 
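+        # Deletion is long-running; the future built below resolves to
+        # google.protobuf.Empty once the server finishes. A caller would
+        # typically block on it, e.g. (a sketch):
+        #   client.delete_training_pipeline(name=name).result()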
+ response = ga_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def cancel_training_pipeline( + self, + request: pipeline_service.CancelTrainingPipelineRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> None: + r"""Cancels a TrainingPipeline. Starts asynchronous cancellation on + the TrainingPipeline. The server makes a best effort to cancel + the pipeline, but success is not guaranteed. Clients can use + ``PipelineService.GetTrainingPipeline`` + or other methods to check whether the cancellation succeeded or + whether the pipeline completed despite cancellation. On + successful cancellation, the TrainingPipeline is not deleted; + instead it becomes a pipeline with a + ``TrainingPipeline.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``TrainingPipeline.state`` + is set to ``CANCELLED``. + + Args: + request (google.cloud.aiplatform_v1.types.CancelTrainingPipelineRequest): + The request object. Request message for + ``PipelineService.CancelTrainingPipeline``. + name (str): + Required. The name of the TrainingPipeline to cancel. + Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a pipeline_service.CancelTrainingPipelineRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, pipeline_service.CancelTrainingPipelineRequest): + request = pipeline_service.CancelTrainingPipelineRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.cancel_training_pipeline] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
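+        # A successful cancel request returns Empty and the pipeline
+        # transitions to CANCELLED asynchronously (see the docstring above),
+        # so there is nothing useful to return.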
+ rpc( + request, retry=retry, timeout=timeout, metadata=metadata, + ) + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("PipelineServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py new file mode 100644 index 0000000000..0f3503ff5a --- /dev/null +++ b/google/cloud/aiplatform_v1/services/pipeline_service/pagers.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple + +from google.cloud.aiplatform_v1.types import pipeline_service +from google.cloud.aiplatform_v1.types import training_pipeline + + +class ListTrainingPipelinesPager: + """A pager for iterating through ``list_training_pipelines`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListTrainingPipelinesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``training_pipelines`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListTrainingPipelines`` requests and continue to iterate + through the ``training_pipelines`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListTrainingPipelinesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., pipeline_service.ListTrainingPipelinesResponse], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListTrainingPipelinesRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListTrainingPipelinesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = pipeline_service.ListTrainingPipelinesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[pipeline_service.ListTrainingPipelinesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[training_pipeline.TrainingPipeline]: + for page in self.pages: + yield from page.training_pipelines + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTrainingPipelinesAsyncPager: + """A pager for iterating through ``list_training_pipelines`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListTrainingPipelinesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``training_pipelines`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListTrainingPipelines`` requests and continue to iterate + through the ``training_pipelines`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListTrainingPipelinesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[pipeline_service.ListTrainingPipelinesResponse] + ], + request: pipeline_service.ListTrainingPipelinesRequest, + response: pipeline_service.ListTrainingPipelinesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListTrainingPipelinesRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListTrainingPipelinesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ """ + self._method = method + self._request = pipeline_service.ListTrainingPipelinesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[pipeline_service.ListTrainingPipelinesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[training_pipeline.TrainingPipeline]: + async def async_generator(): + async for page in self.pages: + for response in page.training_pipelines: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/__init__.py new file mode 100644 index 0000000000..9d4610087a --- /dev/null +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import PipelineServiceTransport +from .grpc import PipelineServiceGrpcTransport +from .grpc_asyncio import PipelineServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[PipelineServiceTransport]] +_transport_registry["grpc"] = PipelineServiceGrpcTransport +_transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport + +__all__ = ( + "PipelineServiceTransport", + "PipelineServiceGrpcTransport", + "PipelineServiceGrpcAsyncIOTransport", +) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py new file mode 100644 index 0000000000..e4bc8e66a8 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/base.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import abc
+import typing
+import pkg_resources
+
+from google import auth  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google.auth import credentials  # type: ignore
+
+from google.cloud.aiplatform_v1.types import pipeline_service
+from google.cloud.aiplatform_v1.types import training_pipeline
+from google.cloud.aiplatform_v1.types import training_pipeline as gca_training_pipeline
+from google.longrunning import operations_pb2 as operations  # type: ignore
+from google.protobuf import empty_pb2 as empty  # type: ignore
+
+
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            "google-cloud-aiplatform",
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+class PipelineServiceTransport(abc.ABC):
+    """Abstract transport class for PipelineService."""
+
+    AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: typing.Optional[str] = None,
+        scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+        quota_project_id: typing.Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        **kwargs,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+        """
+        # Save the hostname. Default to port 443 (HTTPS) if none is specified.
+        if ":" not in host:
+            host += ":443"
+        self._host = host
+
+        # If no credentials are provided, then determine the appropriate
+        # defaults.
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs(
+                "'credentials_file' and 'credentials' are mutually exclusive"
+            )
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file, scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(
+                scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        # Save the credentials.
+        self._credentials = credentials
+
+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
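+        # gapic_v1.method.wrap_method layers retry/timeout defaults and
+        # client-info metadata onto each raw transport callable; the client
+        # later looks these up as
+        # self._transport._wrapped_methods[self._transport.<rpc_name>].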
+ self._wrapped_methods = { + self.create_training_pipeline: gapic_v1.method.wrap_method( + self.create_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.get_training_pipeline: gapic_v1.method.wrap_method( + self.get_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.list_training_pipelines: gapic_v1.method.wrap_method( + self.list_training_pipelines, + default_timeout=None, + client_info=client_info, + ), + self.delete_training_pipeline: gapic_v1.method.wrap_method( + self.delete_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + self.cancel_training_pipeline: gapic_v1.method.wrap_method( + self.cancel_training_pipeline, + default_timeout=None, + client_info=client_info, + ), + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def create_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.CreateTrainingPipelineRequest], + typing.Union[ + gca_training_pipeline.TrainingPipeline, + typing.Awaitable[gca_training_pipeline.TrainingPipeline], + ], + ]: + raise NotImplementedError() + + @property + def get_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.GetTrainingPipelineRequest], + typing.Union[ + training_pipeline.TrainingPipeline, + typing.Awaitable[training_pipeline.TrainingPipeline], + ], + ]: + raise NotImplementedError() + + @property + def list_training_pipelines( + self, + ) -> typing.Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + typing.Union[ + pipeline_service.ListTrainingPipelinesResponse, + typing.Awaitable[pipeline_service.ListTrainingPipelinesResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def cancel_training_pipeline( + self, + ) -> typing.Callable[ + [pipeline_service.CancelTrainingPipelineRequest], + typing.Union[empty.Empty, typing.Awaitable[empty.Empty]], + ]: + raise NotImplementedError() + + +__all__ = ("PipelineServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py new file mode 100644 index 0000000000..818144f008 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc.py @@ -0,0 +1,426 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import warnings
+from typing import Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import grpc_helpers  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+
+from google.cloud.aiplatform_v1.types import pipeline_service
+from google.cloud.aiplatform_v1.types import training_pipeline
+from google.cloud.aiplatform_v1.types import training_pipeline as gca_training_pipeline
+from google.longrunning import operations_pb2 as operations  # type: ignore
+from google.protobuf import empty_pb2 as empty  # type: ignore
+
+from .base import PipelineServiceTransport, DEFAULT_CLIENT_INFO
+
+
+class PipelineServiceGrpcTransport(PipelineServiceTransport):
+    """gRPC backend transport for PipelineService.
+
+    A service for creating and managing AI Platform's pipelines.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built
+    on top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _stubs: Dict[str, Callable]
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Sequence[str] = None,
+        channel: grpc.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+            channel (Optional[grpc.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None + + # Run the base constructor. 
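+        # At this point self._grpc_channel has been set from exactly one
+        # source: an explicitly supplied `channel`, the deprecated
+        # `api_mtls_endpoint` branch, or a freshly created channel (mTLS when
+        # a client certificate was available). The base constructor below
+        # records host/credentials and pre-wraps the RPC methods.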
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> grpc.Channel:
+        """Create and return a gRPC channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    @property
+    def grpc_channel(self) -> grpc.Channel:
+        """Return the channel designed to connect to this service.
+        """
+        return self._grpc_channel
+
+    @property
+    def operations_client(self) -> operations_v1.OperationsClient:
+        """Create the client designed to process long-running operations.
+
+        This property caches on the instance; repeated calls return the same
+        client.
+        """
+        # Sanity check: Only create a new client if we do not already have one.
+        if self._operations_client is None:
+            self._operations_client = operations_v1.OperationsClient(self.grpc_channel)
+
+        # Return the client from cache.
+        return self._operations_client
+
+    @property
+    def create_training_pipeline(
+        self,
+    ) -> Callable[
+        [pipeline_service.CreateTrainingPipelineRequest],
+        gca_training_pipeline.TrainingPipeline,
+    ]:
+        r"""Return a callable for the create training pipeline method over gRPC.
+
+        Creates a TrainingPipeline. The server attempts to run a
+        created TrainingPipeline right away.
+
+        Returns:
+            Callable[[~.CreateTrainingPipelineRequest],
+                    ~.TrainingPipeline]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
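+        # The stub is created once and cached in self._stubs; later property
+        # accesses reuse the same unary-unary callable rather than
+        # re-registering it on the channel.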
+ if "create_training_pipeline" not in self._stubs: + self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/CreateTrainingPipeline", + request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize, + response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize, + ) + return self._stubs["create_training_pipeline"] + + @property + def get_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.GetTrainingPipelineRequest], + training_pipeline.TrainingPipeline, + ]: + r"""Return a callable for the get training pipeline method over gRPC. + + Gets a TrainingPipeline. + + Returns: + Callable[[~.GetTrainingPipelineRequest], + ~.TrainingPipeline]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_training_pipeline" not in self._stubs: + self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/GetTrainingPipeline", + request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize, + response_deserializer=training_pipeline.TrainingPipeline.deserialize, + ) + return self._stubs["get_training_pipeline"] + + @property + def list_training_pipelines( + self, + ) -> Callable[ + [pipeline_service.ListTrainingPipelinesRequest], + pipeline_service.ListTrainingPipelinesResponse, + ]: + r"""Return a callable for the list training pipelines method over gRPC. + + Lists TrainingPipelines in a Location. + + Returns: + Callable[[~.ListTrainingPipelinesRequest], + ~.ListTrainingPipelinesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_training_pipelines" not in self._stubs: + self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/ListTrainingPipelines", + request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize, + response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize, + ) + return self._stubs["list_training_pipelines"] + + @property + def delete_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.DeleteTrainingPipelineRequest], operations.Operation + ]: + r"""Return a callable for the delete training pipeline method over gRPC. + + Deletes a TrainingPipeline. + + Returns: + Callable[[~.DeleteTrainingPipelineRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
+ if "delete_training_pipeline" not in self._stubs: + self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/DeleteTrainingPipeline", + request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_training_pipeline"] + + @property + def cancel_training_pipeline( + self, + ) -> Callable[[pipeline_service.CancelTrainingPipelineRequest], empty.Empty]: + r"""Return a callable for the cancel training pipeline method over gRPC. + + Cancels a TrainingPipeline. Starts asynchronous cancellation on + the TrainingPipeline. The server makes a best effort to cancel + the pipeline, but success is not guaranteed. Clients can use + ``PipelineService.GetTrainingPipeline`` + or other methods to check whether the cancellation succeeded or + whether the pipeline completed despite cancellation. On + successful cancellation, the TrainingPipeline is not deleted; + instead it becomes a pipeline with a + ``TrainingPipeline.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``TrainingPipeline.state`` + is set to ``CANCELLED``. + + Returns: + Callable[[~.CancelTrainingPipelineRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "cancel_training_pipeline" not in self._stubs: + self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/CancelTrainingPipeline", + request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_training_pipeline"] + + +__all__ = ("PipelineServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..ceed94071f --- /dev/null +++ b/google/cloud/aiplatform_v1/services/pipeline_service/transports/grpc_asyncio.py @@ -0,0 +1,435 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1.types import pipeline_service
+from google.cloud.aiplatform_v1.types import training_pipeline
+from google.cloud.aiplatform_v1.types import training_pipeline as gca_training_pipeline
+from google.longrunning import operations_pb2 as operations  # type: ignore
+from google.protobuf import empty_pb2 as empty  # type: ignore
+
+from .base import PipelineServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import PipelineServiceGrpcTransport
+
+
+class PipelineServiceGrpcAsyncIOTransport(PipelineServiceTransport):
+    """gRPC AsyncIO backend transport for PipelineService.
+
+    A service for creating and managing AI Platform's pipelines.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+ """ + scopes = scopes or cls.AUTH_SCOPES + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: aio.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id=None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[aio.Channel]): A ``Channel`` instance through + which to make calls. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or applicatin default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. 
+ """ + self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + self._operations_client = None + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. 
+        return self._operations_client
+
+    @property
+    def create_training_pipeline(
+        self,
+    ) -> Callable[
+        [pipeline_service.CreateTrainingPipelineRequest],
+        Awaitable[gca_training_pipeline.TrainingPipeline],
+    ]:
+        r"""Return a callable for the create training pipeline method over gRPC.
+
+        Creates a TrainingPipeline. The server attempts to run
+        a newly created TrainingPipeline right away.
+
+        Returns:
+            Callable[[~.CreateTrainingPipelineRequest],
+                    Awaitable[~.TrainingPipeline]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "create_training_pipeline" not in self._stubs:
+            self._stubs["create_training_pipeline"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.PipelineService/CreateTrainingPipeline",
+                request_serializer=pipeline_service.CreateTrainingPipelineRequest.serialize,
+                response_deserializer=gca_training_pipeline.TrainingPipeline.deserialize,
+            )
+        return self._stubs["create_training_pipeline"]
+
+    @property
+    def get_training_pipeline(
+        self,
+    ) -> Callable[
+        [pipeline_service.GetTrainingPipelineRequest],
+        Awaitable[training_pipeline.TrainingPipeline],
+    ]:
+        r"""Return a callable for the get training pipeline method over gRPC.
+
+        Gets a TrainingPipeline.
+
+        Returns:
+            Callable[[~.GetTrainingPipelineRequest],
+                    Awaitable[~.TrainingPipeline]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "get_training_pipeline" not in self._stubs:
+            self._stubs["get_training_pipeline"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.PipelineService/GetTrainingPipeline",
+                request_serializer=pipeline_service.GetTrainingPipelineRequest.serialize,
+                response_deserializer=training_pipeline.TrainingPipeline.deserialize,
+            )
+        return self._stubs["get_training_pipeline"]
+
+    @property
+    def list_training_pipelines(
+        self,
+    ) -> Callable[
+        [pipeline_service.ListTrainingPipelinesRequest],
+        Awaitable[pipeline_service.ListTrainingPipelinesResponse],
+    ]:
+        r"""Return a callable for the list training pipelines method over gRPC.
+
+        Lists TrainingPipelines in a Location.
+
+        Returns:
+            Callable[[~.ListTrainingPipelinesRequest],
+                    Awaitable[~.ListTrainingPipelinesResponse]]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "list_training_pipelines" not in self._stubs:
+            self._stubs["list_training_pipelines"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.PipelineService/ListTrainingPipelines",
+                request_serializer=pipeline_service.ListTrainingPipelinesRequest.serialize,
+                response_deserializer=pipeline_service.ListTrainingPipelinesResponse.deserialize,
+            )
+        return self._stubs["list_training_pipelines"]
+
+    @property
+    def delete_training_pipeline(
+        self,
+    ) -> Callable[
+        [pipeline_service.DeleteTrainingPipelineRequest],
+        Awaitable[operations.Operation],
+    ]:
+        r"""Return a callable for the delete training pipeline method over gRPC.
+
+        Deletes a TrainingPipeline.
+ + Returns: + Callable[[~.DeleteTrainingPipelineRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_training_pipeline" not in self._stubs: + self._stubs["delete_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/DeleteTrainingPipeline", + request_serializer=pipeline_service.DeleteTrainingPipelineRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_training_pipeline"] + + @property + def cancel_training_pipeline( + self, + ) -> Callable[ + [pipeline_service.CancelTrainingPipelineRequest], Awaitable[empty.Empty] + ]: + r"""Return a callable for the cancel training pipeline method over gRPC. + + Cancels a TrainingPipeline. Starts asynchronous cancellation on + the TrainingPipeline. The server makes a best effort to cancel + the pipeline, but success is not guaranteed. Clients can use + ``PipelineService.GetTrainingPipeline`` + or other methods to check whether the cancellation succeeded or + whether the pipeline completed despite cancellation. On + successful cancellation, the TrainingPipeline is not deleted; + instead it becomes a pipeline with a + ``TrainingPipeline.error`` + value with a ``google.rpc.Status.code`` of + 1, corresponding to ``Code.CANCELLED``, and + ``TrainingPipeline.state`` + is set to ``CANCELLED``. + + Returns: + Callable[[~.CancelTrainingPipelineRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "cancel_training_pipeline" not in self._stubs: + self._stubs["cancel_training_pipeline"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PipelineService/CancelTrainingPipeline", + request_serializer=pipeline_service.CancelTrainingPipelineRequest.serialize, + response_deserializer=empty.Empty.FromString, + ) + return self._stubs["cancel_training_pipeline"] + + +__all__ = ("PipelineServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/__init__.py b/google/cloud/aiplatform_v1/services/prediction_service/__init__.py new file mode 100644 index 0000000000..0c847693e0 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/prediction_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from .client import PredictionServiceClient +from .async_client import PredictionServiceAsyncClient + +__all__ = ( + "PredictionServiceClient", + "PredictionServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py new file mode 100644 index 0000000000..c0ab09622c --- /dev/null +++ b/google/cloud/aiplatform_v1/services/prediction_service/async_client.py @@ -0,0 +1,263 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.cloud.aiplatform_v1.types import prediction_service +from google.protobuf import struct_pb2 as struct # type: ignore + +from .transports.base import PredictionServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import PredictionServiceGrpcAsyncIOTransport +from .client import PredictionServiceClient + + +class PredictionServiceAsyncClient: + """A service for online predictions and explanations.""" + + _client: PredictionServiceClient + + DEFAULT_ENDPOINT = PredictionServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = PredictionServiceClient.DEFAULT_MTLS_ENDPOINT + + endpoint_path = staticmethod(PredictionServiceClient.endpoint_path) + parse_endpoint_path = staticmethod(PredictionServiceClient.parse_endpoint_path) + + common_billing_account_path = staticmethod( + PredictionServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PredictionServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(PredictionServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + PredictionServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + PredictionServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PredictionServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(PredictionServiceClient.common_project_path) + parse_common_project_path = staticmethod( + PredictionServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod(PredictionServiceClient.common_location_path) + parse_common_location_path = staticmethod( + PredictionServiceClient.parse_common_location_path + ) + + from_service_account_info = PredictionServiceClient.from_service_account_info + from_service_account_file = PredictionServiceClient.from_service_account_file + from_service_account_json = 
from_service_account_file + + @property + def transport(self) -> PredictionServiceTransport: + """Return the transport used by the client instance. + + Returns: + PredictionServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + get_transport_class = functools.partial( + type(PredictionServiceClient).get_transport_class, type(PredictionServiceClient) + ) + + def __init__( + self, + *, + credentials: credentials.Credentials = None, + transport: Union[str, PredictionServiceTransport] = "grpc_asyncio", + client_options: ClientOptions = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the prediction service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, ~.PredictionServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. + client_options (ClientOptions): Custom options for the client. It + won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + + self._client = PredictionServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + async def predict( + self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: + r"""Perform an online prediction. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.PredictRequest`): + The request object. Request message for + ``PredictionService.Predict``. + endpoint (:class:`str`): + Required. The name of the Endpoint requested to serve + the prediction. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + instances (:class:`Sequence[google.protobuf.struct_pb2.Value]`): + Required. The instances that are the input to the + prediction call. 
A DeployedModel may have an upper limit + on the number of instances it supports per request, and + when it is exceeded the prediction call errors in case + of AutoML Models, or, in case of customer created + Models, the behaviour is as documented by that Model. + The schema of any single instance may be specified via + Endpoint's DeployedModels' + [Model's][google.cloud.aiplatform.v1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] + ``instance_schema_uri``. + + This corresponds to the ``instances`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + parameters (:class:`google.protobuf.struct_pb2.Value`): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] + ``parameters_schema_uri``. + + This corresponds to the ``parameters`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.PredictResponse: + Response message for + ``PredictionService.Predict``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([endpoint, instances, parameters]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = prediction_service.PredictRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if parameters is not None: + request.parameters = parameters + + if instances: + request.instances.extend(instances) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.predict, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. 
+ return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("PredictionServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/client.py b/google/cloud/aiplatform_v1/services/prediction_service/client.py new file mode 100644 index 0000000000..55c52b48f4 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/prediction_service/client.py @@ -0,0 +1,468 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.cloud.aiplatform_v1.types import prediction_service +from google.protobuf import struct_pb2 as struct # type: ignore + +from .transports.base import PredictionServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import PredictionServiceGrpcTransport +from .transports.grpc_asyncio import PredictionServiceGrpcAsyncIOTransport + + +class PredictionServiceClientMeta(type): + """Metaclass for the PredictionService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[PredictionServiceTransport]] + _transport_registry["grpc"] = PredictionServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[PredictionServiceTransport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). 
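+        # For example, PredictionServiceClient.get_transport_class("grpc_asyncio")
+        # returns PredictionServiceGrpcAsyncIOTransport, while calling it with
+        # no label falls through to the registry's first entry ("grpc") via
+        # the line below.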
+        return next(iter(cls._transport_registry.values()))
+
+
+class PredictionServiceClient(metaclass=PredictionServiceClientMeta):
+    """A service for online predictions and explanations."""
+
+    @staticmethod
+    def _get_default_mtls_endpoint(api_endpoint):
+        """Convert api endpoint to mTLS endpoint.
+        Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to
+        "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively.
+        Args:
+            api_endpoint (Optional[str]): the api endpoint to convert.
+        Returns:
+            str: converted mTLS api endpoint.
+        """
+        if not api_endpoint:
+            return api_endpoint
+
+        mtls_endpoint_re = re.compile(
+            r"(?P<name>[^.]+)(?P<mtls>\.mtls)?(?P<sandbox>\.sandbox)?(?P<googledomain>\.googleapis\.com)?"
+        )
+
+        m = mtls_endpoint_re.match(api_endpoint)
+        name, mtls, sandbox, googledomain = m.groups()
+        if mtls or not googledomain:
+            return api_endpoint
+
+        if sandbox:
+            return api_endpoint.replace(
+                "sandbox.googleapis.com", "mtls.sandbox.googleapis.com"
+            )
+
+        return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com")
+
+    DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
+    DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__(  # type: ignore
+        DEFAULT_ENDPOINT
+    )
+
+    @classmethod
+    def from_service_account_info(cls, info: dict, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials info.
+
+        Args:
+            info (dict): The service account private key info.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            PredictionServiceClient: The constructed client.
+        """
+        credentials = service_account.Credentials.from_service_account_info(info)
+        kwargs["credentials"] = credentials
+        return cls(*args, **kwargs)
+
+    @classmethod
+    def from_service_account_file(cls, filename: str, *args, **kwargs):
+        """Creates an instance of this client using the provided credentials
+        file.
+
+        Args:
+            filename (str): The path to the service account private key json
+                file.
+            args: Additional arguments to pass to the constructor.
+            kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            PredictionServiceClient: The constructed client.
+        """
+        credentials = service_account.Credentials.from_service_account_file(filename)
+        kwargs["credentials"] = credentials
+        return cls(*args, **kwargs)
+
+    from_service_account_json = from_service_account_file
+
+    @property
+    def transport(self) -> PredictionServiceTransport:
+        """Return the transport used by the client instance.
+
+        Returns:
+            PredictionServiceTransport: The transport used by the client instance.
+ """ + return self._transport + + @staticmethod + def endpoint_path(project: str, location: str, endpoint: str,) -> str: + """Return a fully-qualified endpoint string.""" + return "projects/{project}/locations/{location}/endpoints/{endpoint}".format( + project=project, location=location, endpoint=endpoint, + ) + + @staticmethod + def parse_endpoint_path(path: str) -> Dict[str, str]: + """Parse a endpoint path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/endpoints/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, PredictionServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the prediction service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, PredictionServiceTransport]): The + transport to use. If set to None, a transport is chosen + automatically. 
+ client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, PredictionServiceTransport): + # transport is a PredictionServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." 
+ ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def predict( + self, + request: prediction_service.PredictRequest = None, + *, + endpoint: str = None, + instances: Sequence[struct.Value] = None, + parameters: struct.Value = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> prediction_service.PredictResponse: + r"""Perform an online prediction. + + Args: + request (google.cloud.aiplatform_v1.types.PredictRequest): + The request object. Request message for + ``PredictionService.Predict``. + endpoint (str): + Required. The name of the Endpoint requested to serve + the prediction. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + + This corresponds to the ``endpoint`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + instances (Sequence[google.protobuf.struct_pb2.Value]): + Required. The instances that are the input to the + prediction call. A DeployedModel may have an upper limit + on the number of instances it supports per request, and + when it is exceeded the prediction call errors in case + of AutoML Models, or, in case of customer created + Models, the behaviour is as documented by that Model. + The schema of any single instance may be specified via + Endpoint's DeployedModels' + [Model's][google.cloud.aiplatform.v1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] + ``instance_schema_uri``. + + This corresponds to the ``instances`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + parameters (google.protobuf.struct_pb2.Value): + The parameters that govern the prediction. The schema of + the parameters may be specified via Endpoint's + DeployedModels' [Model's + ][google.cloud.aiplatform.v1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] + ``parameters_schema_uri``. + + This corresponds to the ``parameters`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.PredictResponse: + Response message for + ``PredictionService.Predict``. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([endpoint, instances, parameters]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a prediction_service.PredictRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. 
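+        # Illustrative sketch (``client`` and ``EP`` are assumed names): the
+        # flattened keywords build the same message, so these two calls are
+        # equivalent:
+        #
+        #   client.predict(request=prediction_service.PredictRequest(endpoint=EP))
+        #   client.predict(endpoint=EP)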
+ if not isinstance(request, prediction_service.PredictRequest): + request = prediction_service.PredictRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if endpoint is not None: + request.endpoint = endpoint + if parameters is not None: + request.parameters = parameters + + if instances: + request.instances.extend(instances) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.predict] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("endpoint", request.endpoint),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("PredictionServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/__init__.py new file mode 100644 index 0000000000..9ec1369a05 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/__init__.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import PredictionServiceTransport +from .grpc import PredictionServiceGrpcTransport +from .grpc_asyncio import PredictionServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] +_transport_registry["grpc"] = PredictionServiceGrpcTransport +_transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport + +__all__ = ( + "PredictionServiceTransport", + "PredictionServiceGrpcTransport", + "PredictionServiceGrpcAsyncIOTransport", +) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py new file mode 100644 index 0000000000..311639daaf --- /dev/null +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/base.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import abc
+import typing
+import pkg_resources
+
+from google import auth  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.auth import credentials  # type: ignore
+
+from google.cloud.aiplatform_v1.types import prediction_service
+
+
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            "google-cloud-aiplatform",
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+class PredictionServiceTransport(abc.ABC):
+    """Abstract transport class for PredictionService."""
+
+    AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: typing.Optional[str] = None,
+        scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+        quota_project_id: typing.Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        **kwargs,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+        """
+        # Save the hostname. Default to port 443 (HTTPS) if none is specified.
+        if ":" not in host:
+            host += ":443"
+        self._host = host
+
+        # If no credentials are provided, then determine the appropriate
+        # defaults.
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs(
+                "'credentials_file' and 'credentials' are mutually exclusive"
+            )
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file, scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(
+                scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        # Save the credentials.
+        self._credentials = credentials
+
+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
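+        # Wrapping binds the retry/timeout defaults and client info once, so
+        # a client layer can later invoke the RPC as, for example:
+        #
+        #   rpc = self._wrapped_methods[self.predict]
+        #   response = rpc(request, retry=retry, timeout=timeout, metadata=metadata)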
+ self._wrapped_methods = { + self.predict: gapic_v1.method.wrap_method( + self.predict, default_timeout=None, client_info=client_info, + ), + } + + @property + def predict( + self, + ) -> typing.Callable[ + [prediction_service.PredictRequest], + typing.Union[ + prediction_service.PredictResponse, + typing.Awaitable[prediction_service.PredictResponse], + ], + ]: + raise NotImplementedError() + + +__all__ = ("PredictionServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py new file mode 100644 index 0000000000..4fcfe5b442 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc.py @@ -0,0 +1,280 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import grpc_helpers # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore + +from google.cloud.aiplatform_v1.types import prediction_service + +from .base import PredictionServiceTransport, DEFAULT_CLIENT_INFO + + +class PredictionServiceGrpcTransport(PredictionServiceTransport): + """gRPC backend transport for PredictionService. + + A service for online predictions and explanations. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "aiplatform.googleapis.com", + credentials: credentials.Credentials = None, + credentials_file: str = None, + scopes: Sequence[str] = None, + channel: grpc.Channel = None, + api_mtls_endpoint: str = None, + client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, + ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): The hostname to connect to. + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if ``channel`` is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. 
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+            channel (Optional[grpc.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure mutual TLS channel. It is
+                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        self._ssl_channel_credentials = ssl_channel_credentials
+
+        if api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+        if client_cert_source:
+            warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+            self._ssl_channel_credentials = None
+        elif api_mtls_endpoint:
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
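+            # Illustrative note: the -1 values for grpc.max_send_message_length
+            # and grpc.max_receive_message_length below lift gRPC's default
+            # message-size caps (the receive limit defaults to 4 MB), which
+            # matters for large prediction payloads.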
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            self._ssl_channel_credentials = ssl_credentials
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            if client_cert_source_for_mtls and not ssl_channel_credentials:
+                cert, key = client_cert_source_for_mtls()
+                self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=self._ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+
+        self._stubs = {}  # type: Dict[str, Callable]
+
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> grpc.Channel:
+        """Create and return a gRPC channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    @property
+    def grpc_channel(self) -> grpc.Channel:
+        """Return the channel designed to connect to this service.
+        """
+        return self._grpc_channel
+
+    @property
+    def predict(
+        self,
+    ) -> Callable[
+        [prediction_service.PredictRequest], prediction_service.PredictResponse
+    ]:
+        r"""Return a callable for the predict method over gRPC.
+
+        Perform an online prediction.
+
+        Returns:
+            Callable[[~.PredictRequest],
+                ~.PredictResponse]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
+        if "predict" not in self._stubs:
+            self._stubs["predict"] = self.grpc_channel.unary_unary(
+                "/google.cloud.aiplatform.v1.PredictionService/Predict",
+                request_serializer=prediction_service.PredictRequest.serialize,
+                response_deserializer=prediction_service.PredictResponse.deserialize,
+            )
+        return self._stubs["predict"]
+
+
+__all__ = ("PredictionServiceGrpcTransport",)
diff --git a/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py
new file mode 100644
index 0000000000..620f340813
--- /dev/null
+++ b/google/cloud/aiplatform_v1/services/prediction_service/transports/grpc_asyncio.py
@@ -0,0 +1,285 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import warnings
+from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import grpc_helpers_async  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+from grpc.experimental import aio  # type: ignore
+
+from google.cloud.aiplatform_v1.types import prediction_service
+
+from .base import PredictionServiceTransport, DEFAULT_CLIENT_INFO
+from .grpc import PredictionServiceGrpcTransport
+
+
+class PredictionServiceGrpcAsyncIOTransport(PredictionServiceTransport):
+    """gRPC AsyncIO backend transport for PredictionService.
+
+    A service for online predictions and explanations.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers_async.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        channel: aio.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id=None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            channel (Optional[aio.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for the grpc channel. It is ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure a mutual TLS channel. It is
+                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        self._ssl_channel_credentials = ssl_channel_credentials
+
+        if api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+        if client_cert_source:
+            warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+            self._ssl_channel_credentials = None
+        elif api_mtls_endpoint:
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=ssl_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            self._ssl_channel_credentials = ssl_credentials
+        else:
+            host = host if ":" in host else host + ":443"
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            if client_cert_source_for_mtls and not ssl_channel_credentials:
+                cert, key = client_cert_source_for_mtls()
+                self._ssl_channel_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+
+            # create a new channel. The provided one is ignored.
+            self._grpc_channel = type(self).create_channel(
+                host,
+                credentials=credentials,
+                credentials_file=credentials_file,
+                ssl_credentials=self._ssl_channel_credentials,
+                scopes=scopes or self.AUTH_SCOPES,
+                quota_project_id=quota_project_id,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+
+        # Run the base constructor.
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
+
+        self._stubs = {}
+
+    @property
+    def grpc_channel(self) -> aio.Channel:
+        """Return the channel designed to connect to this service.
+
+        This property caches on the instance; repeated calls return
+        the same channel.
+        """
+        # Return the channel from cache.
+        return self._grpc_channel
+
+    @property
+    def predict(
+        self,
+    ) -> Callable[
+        [prediction_service.PredictRequest],
+        Awaitable[prediction_service.PredictResponse],
+    ]:
+        r"""Return a callable for the predict method over gRPC.
+
+        Perform an online prediction.
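+
+        Example (illustrative sketch; assumes an already-built ``request``):
+        the returned stub is awaitable when called, e.g.
+        ``response = await transport.predict(request)``.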
+ + Returns: + Callable[[~.PredictRequest], + Awaitable[~.PredictResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "predict" not in self._stubs: + self._stubs["predict"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.PredictionService/Predict", + request_serializer=prediction_service.PredictRequest.serialize, + response_deserializer=prediction_service.PredictResponse.deserialize, + ) + return self._stubs["predict"] + + +__all__ = ("PredictionServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/__init__.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/__init__.py new file mode 100644 index 0000000000..49e9cdf0a0 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/__init__.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .client import SpecialistPoolServiceClient +from .async_client import SpecialistPoolServiceAsyncClient + +__all__ = ( + "SpecialistPoolServiceClient", + "SpecialistPoolServiceAsyncClient", +) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py new file mode 100644 index 0000000000..496f6aa319 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/async_client.py @@ -0,0 +1,639 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +import functools +import re +from typing import Dict, Sequence, Tuple, Type, Union +import pkg_resources + +import google.api_core.client_options as ClientOptions # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.specialist_pool_service import pagers +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import specialist_pool +from google.cloud.aiplatform_v1.types import specialist_pool as gca_specialist_pool +from google.cloud.aiplatform_v1.types import specialist_pool_service +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + +from .transports.base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc_asyncio import SpecialistPoolServiceGrpcAsyncIOTransport +from .client import SpecialistPoolServiceClient + + +class SpecialistPoolServiceAsyncClient: + """A service for creating and managing Customer SpecialistPools. + When customers start Data Labeling jobs, they can reuse/create + Specialist Pools to bring their own Specialists to label the + data. Customers can add/remove Managers for the Specialist Pool + on Cloud console, then Managers will get email notifications to + manage Specialists and tasks on CrowdCompute console. + """ + + _client: SpecialistPoolServiceClient + + DEFAULT_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = SpecialistPoolServiceClient.DEFAULT_MTLS_ENDPOINT + + specialist_pool_path = staticmethod( + SpecialistPoolServiceClient.specialist_pool_path + ) + parse_specialist_pool_path = staticmethod( + SpecialistPoolServiceClient.parse_specialist_pool_path + ) + + common_billing_account_path = staticmethod( + SpecialistPoolServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + SpecialistPoolServiceClient.parse_common_billing_account_path + ) + + common_folder_path = staticmethod(SpecialistPoolServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + SpecialistPoolServiceClient.parse_common_folder_path + ) + + common_organization_path = staticmethod( + SpecialistPoolServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + SpecialistPoolServiceClient.parse_common_organization_path + ) + + common_project_path = staticmethod(SpecialistPoolServiceClient.common_project_path) + parse_common_project_path = staticmethod( + SpecialistPoolServiceClient.parse_common_project_path + ) + + common_location_path = staticmethod( + SpecialistPoolServiceClient.common_location_path + ) + parse_common_location_path = staticmethod( + SpecialistPoolServiceClient.parse_common_location_path + ) + + from_service_account_info = SpecialistPoolServiceClient.from_service_account_info + from_service_account_file = SpecialistPoolServiceClient.from_service_account_file + from_service_account_json = from_service_account_file + + @property + def transport(self) -> SpecialistPoolServiceTransport: + """Return the transport used by the client instance. 
+
+        Returns:
+            SpecialistPoolServiceTransport: The transport used by the client instance.
+        """
+        return self._client.transport
+
+    get_transport_class = functools.partial(
+        type(SpecialistPoolServiceClient).get_transport_class,
+        type(SpecialistPoolServiceClient),
+    )
+
+    def __init__(
+        self,
+        *,
+        credentials: credentials.Credentials = None,
+        transport: Union[str, SpecialistPoolServiceTransport] = "grpc_asyncio",
+        client_options: ClientOptions = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the specialist pool service client.
+
+        Args:
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            transport (Union[str, ~.SpecialistPoolServiceTransport]): The
+                transport to use. If set to None, a transport is chosen
+                automatically.
+            client_options (ClientOptions): Custom options for the client. It
+                won't take effect if a ``transport`` instance is provided.
+                (1) The ``api_endpoint`` property can be used to override the
+                default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT
+                environment variable can also be used to override the endpoint:
+                "always" (always use the default mTLS endpoint), "never" (always
+                use the default regular endpoint) and "auto" (auto switch to the
+                default mTLS endpoint if client certificate is present, this is
+                the default value). However, the ``api_endpoint`` property takes
+                precedence if provided.
+                (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable
+                is "true", then the ``client_cert_source`` property can be used
+                to provide client certificate for mutual TLS transport. If
+                not provided, the default SSL client certificate will be used if
+                present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not
+                set, no client certificate will be used.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+        """
+
+        self._client = SpecialistPoolServiceClient(
+            credentials=credentials,
+            transport=transport,
+            client_options=client_options,
+            client_info=client_info,
+        )
+
+    async def create_specialist_pool(
+        self,
+        request: specialist_pool_service.CreateSpecialistPoolRequest = None,
+        *,
+        parent: str = None,
+        specialist_pool: gca_specialist_pool.SpecialistPool = None,
+        retry: retries.Retry = gapic_v1.method.DEFAULT,
+        timeout: float = None,
+        metadata: Sequence[Tuple[str, str]] = (),
+    ) -> operation_async.AsyncOperation:
+        r"""Creates a SpecialistPool.
+
+        Args:
+            request (:class:`google.cloud.aiplatform_v1.types.CreateSpecialistPoolRequest`):
+                The request object. Request message for
+                ``SpecialistPoolService.CreateSpecialistPool``.
+            parent (:class:`str`):
+                Required. The parent Project name for the new
+                SpecialistPool. The form is
+                ``projects/{project}/locations/{location}``.
+
+                This corresponds to the ``parent`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+            specialist_pool (:class:`google.cloud.aiplatform_v1.types.SpecialistPool`):
+                Required. The SpecialistPool to
+                create.
+
+                This corresponds to the ``specialist_pool`` field
+                on the ``request`` instance; if ``request`` is provided, this
+                should not be set.
+
+            retry (google.api_core.retry.Retry): Designation of what errors, if any,
+                should be retried.
+ timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1.types.SpecialistPool` SpecialistPool represents customers' own workforce to work on their data + labeling jobs. It includes a group of specialist + managers who are responsible for managing the + labelers in this pool as well as customers' data + labeling jobs associated with this pool. Customers + create specialist pool as well as start data labeling + jobs on Cloud, managers and labelers work with the + jobs using CrowdCompute console. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, specialist_pool]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = specialist_pool_service.CreateSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if specialist_pool is not None: + request.specialist_pool = specialist_pool + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.create_specialist_pool, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_specialist_pool.SpecialistPool, + metadata_type=specialist_pool_service.CreateSpecialistPoolOperationMetadata, + ) + + # Done; return the response. + return response + + async def get_specialist_pool( + self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: + r"""Gets a SpecialistPool. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.GetSpecialistPoolRequest`): + The request object. Request message for + ``SpecialistPoolService.GetSpecialistPool``. + name (:class:`str`): + Required. The name of the SpecialistPool resource. The + form is + + ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}``. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
+ + Returns: + google.cloud.aiplatform_v1.types.SpecialistPool: + SpecialistPool represents customers' + own workforce to work on their data + labeling jobs. It includes a group of + specialist managers who are responsible + for managing the labelers in this pool + as well as customers' data labeling jobs + associated with this pool. + Customers create specialist pool as well + as start data labeling jobs on Cloud, + managers and labelers work with the jobs + using CrowdCompute console. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = specialist_pool_service.GetSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.get_specialist_pool, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + async def list_specialist_pools( + self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsAsyncPager: + r"""Lists SpecialistPools in a Location. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.ListSpecialistPoolsRequest`): + The request object. Request message for + ``SpecialistPoolService.ListSpecialistPools``. + parent (:class:`str`): + Required. The name of the SpecialistPool's parent + resource. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.specialist_pool_service.pagers.ListSpecialistPoolsAsyncPager: + Response message for + ``SpecialistPoolService.ListSpecialistPools``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." 
+ ) + + request = specialist_pool_service.ListSpecialistPoolsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.list_specialist_pools, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListSpecialistPoolsAsyncPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_specialist_pool( + self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Deletes a SpecialistPool as well as all Specialists + in the pool. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.DeleteSpecialistPoolRequest`): + The request object. Request message for + ``SpecialistPoolService.DeleteSpecialistPool``. + name (:class:`str`): + Required. The resource name of the SpecialistPool to + delete. Format: + ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = specialist_pool_service.DeleteSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
+ rpc = gapic_v1.method_async.wrap_method( + self._client._transport.delete_specialist_pool, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + async def update_specialist_pool( + self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> operation_async.AsyncOperation: + r"""Updates a SpecialistPool. + + Args: + request (:class:`google.cloud.aiplatform_v1.types.UpdateSpecialistPoolRequest`): + The request object. Request message for + ``SpecialistPoolService.UpdateSpecialistPool``. + specialist_pool (:class:`google.cloud.aiplatform_v1.types.SpecialistPool`): + Required. The SpecialistPool which + replaces the resource on the server. + + This corresponds to the ``specialist_pool`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The update mask applies to + the resource. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1.types.SpecialistPool` SpecialistPool represents customers' own workforce to work on their data + labeling jobs. It includes a group of specialist + managers who are responsible for managing the + labelers in this pool as well as customers' data + labeling jobs associated with this pool. Customers + create specialist pool as well as start data labeling + jobs on Cloud, managers and labelers work with the + jobs using CrowdCompute console. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([specialist_pool, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + request = specialist_pool_service.UpdateSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. 
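+        # Illustrative sketch (hypothetical field name, not generated code):
+        # an ``update_mask`` limits the write to the listed fields, e.g.
+        #
+        #   from google.protobuf import field_mask_pb2
+        #   mask = field_mask_pb2.FieldMask(paths=["display_name"])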
+ + if specialist_pool is not None: + request.specialist_pool = specialist_pool + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = gapic_v1.method_async.wrap_method( + self._client._transport.update_specialist_pool, + default_timeout=None, + client_info=DEFAULT_CLIENT_INFO, + ) + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("specialist_pool.name", request.specialist_pool.name),) + ), + ) + + # Send the request. + response = await rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gca_specialist_pool.SpecialistPool, + metadata_type=specialist_pool_service.UpdateSpecialistPoolOperationMetadata, + ) + + # Done; return the response. + return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("SpecialistPoolServiceAsyncClient",) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py new file mode 100644 index 0000000000..c6429b54f8 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/client.py @@ -0,0 +1,841 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from collections import OrderedDict +from distutils import util +import os +import re +from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union +import pkg_resources + +from google.api_core import client_options as client_options_lib # type: ignore +from google.api_core import exceptions # type: ignore +from google.api_core import gapic_v1 # type: ignore +from google.api_core import retry as retries # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.api_core import operation as ga_operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.cloud.aiplatform_v1.services.specialist_pool_service import pagers +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import specialist_pool +from google.cloud.aiplatform_v1.types import specialist_pool as gca_specialist_pool +from google.cloud.aiplatform_v1.types import specialist_pool_service +from google.protobuf import empty_pb2 as empty # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + +from .transports.base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO +from .transports.grpc import SpecialistPoolServiceGrpcTransport +from .transports.grpc_asyncio import SpecialistPoolServiceGrpcAsyncIOTransport + + +class SpecialistPoolServiceClientMeta(type): + """Metaclass for the SpecialistPoolService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[SpecialistPoolServiceTransport]] + _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport + _transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport + + def get_transport_class( + cls, label: str = None, + ) -> Type[SpecialistPoolServiceTransport]: + """Return an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class SpecialistPoolServiceClient(metaclass=SpecialistPoolServiceClientMeta): + """A service for creating and managing Customer SpecialistPools. + When customers start Data Labeling jobs, they can reuse/create + Specialist Pools to bring their own Specialists to label the + data. Customers can add/remove Managers for the Specialist Pool + on Cloud console, then Managers will get email notifications to + manage Specialists and tasks on CrowdCompute console. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Convert api endpoint to mTLS endpoint. + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. 
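+
+        For example (illustrative values, following the conversion rule
+        above): ``"aiplatform.googleapis.com"`` becomes
+        ``"aiplatform.mtls.googleapis.com"``, and a ``*.sandbox.googleapis.com``
+        host becomes ``*.mtls.sandbox.googleapis.com``.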
+ """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + DEFAULT_ENDPOINT = "aiplatform.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + SpecialistPoolServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + SpecialistPoolServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> SpecialistPoolServiceTransport: + """Return the transport used by the client instance. + + Returns: + SpecialistPoolServiceTransport: The transport used by the client instance. 
+ """ + return self._transport + + @staticmethod + def specialist_pool_path(project: str, location: str, specialist_pool: str,) -> str: + """Return a fully-qualified specialist_pool string.""" + return "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( + project=project, location=location, specialist_pool=specialist_pool, + ) + + @staticmethod + def parse_specialist_pool_path(path: str) -> Dict[str, str]: + """Parse a specialist_pool path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/specialistPools/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path(billing_account: str,) -> str: + """Return a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path(folder: str,) -> str: + """Return a fully-qualified folder string.""" + return "folders/{folder}".format(folder=folder,) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path(organization: str,) -> str: + """Return a fully-qualified organization string.""" + return "organizations/{organization}".format(organization=organization,) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path(project: str,) -> str: + """Return a fully-qualified project string.""" + return "projects/{project}".format(project=project,) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path(project: str, location: str,) -> str: + """Return a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + def __init__( + self, + *, + credentials: Optional[credentials.Credentials] = None, + transport: Union[str, SpecialistPoolServiceTransport, None] = None, + client_options: Optional[client_options_lib.ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiate the specialist pool service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Union[str, SpecialistPoolServiceTransport]): The + transport to use. 
If set to None, a transport is chosen + automatically. + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. It won't take effect if a ``transport`` instance is provided. + (1) The ``api_endpoint`` property can be used to override the + default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT + environment variable can also be used to override the endpoint: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto switch to the + default mTLS endpoint if client certificate is present, this is + the default value). However, the ``api_endpoint`` property takes + precedence if provided. + (2) If GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide client certificate for mutual TLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + if isinstance(client_options, dict): + client_options = client_options_lib.from_dict(client_options) + if client_options is None: + client_options = client_options_lib.ClientOptions() + + # Create SSL credentials for mutual TLS if needed. + use_client_cert = bool( + util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) + ) + + client_cert_source_func = None + is_mtls = False + if use_client_cert: + if client_options.client_cert_source: + is_mtls = True + client_cert_source_func = client_options.client_cert_source + else: + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + else: + use_mtls_env = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_mtls_env == "never": + api_endpoint = self.DEFAULT_ENDPOINT + elif use_mtls_env == "always": + api_endpoint = self.DEFAULT_MTLS_ENDPOINT + elif use_mtls_env == "auto": + api_endpoint = ( + self.DEFAULT_MTLS_ENDPOINT if is_mtls else self.DEFAULT_ENDPOINT + ) + else: + raise MutualTLSChannelError( + "Unsupported GOOGLE_API_USE_MTLS_ENDPOINT value. Accepted values: never, auto, always" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + if isinstance(transport, SpecialistPoolServiceTransport): + # transport is a SpecialistPoolServiceTransport instance. + if credentials or client_options.credentials_file: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if client_options.scopes: + raise ValueError( + "When providing a transport instance, " + "provide its scopes directly." 
+ ) + self._transport = transport + else: + Transport = type(self).get_transport_class(transport) + self._transport = Transport( + credentials=credentials, + credentials_file=client_options.credentials_file, + host=api_endpoint, + scopes=client_options.scopes, + client_cert_source_for_mtls=client_cert_source_func, + quota_project_id=client_options.quota_project_id, + client_info=client_info, + ) + + def create_specialist_pool( + self, + request: specialist_pool_service.CreateSpecialistPoolRequest = None, + *, + parent: str = None, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Creates a SpecialistPool. + + Args: + request (google.cloud.aiplatform_v1.types.CreateSpecialistPoolRequest): + The request object. Request message for + ``SpecialistPoolService.CreateSpecialistPool``. + parent (str): + Required. The parent Project name for the new + SpecialistPool. The form is + ``projects/{project}/locations/{location}``. + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + specialist_pool (google.cloud.aiplatform_v1.types.SpecialistPool): + Required. The SpecialistPool to + create. + + This corresponds to the ``specialist_pool`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1.types.SpecialistPool` SpecialistPool represents customers' own workforce to work on their data + labeling jobs. It includes a group of specialist + managers who are responsible for managing the + labelers in this pool as well as customers' data + labeling jobs associated with this pool. Customers + create specialist pool as well as start data labeling + jobs on Cloud, managers and labelers work with the + jobs using CrowdCompute console. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, specialist_pool]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.CreateSpecialistPoolRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.CreateSpecialistPoolRequest): + request = specialist_pool_service.CreateSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + if specialist_pool is not None: + request.specialist_pool = specialist_pool + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. 
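+        # Unlike the async client, which wraps each RPC at call time via
+        # gapic_v1.method_async.wrap_method, the sync transport precomputes
+        # these wrappers; the lookup below is a plain dict access.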
+ rpc = self._transport._wrapped_methods[self._transport.create_specialist_pool] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + gca_specialist_pool.SpecialistPool, + metadata_type=specialist_pool_service.CreateSpecialistPoolOperationMetadata, + ) + + # Done; return the response. + return response + + def get_specialist_pool( + self, + request: specialist_pool_service.GetSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> specialist_pool.SpecialistPool: + r"""Gets a SpecialistPool. + + Args: + request (google.cloud.aiplatform_v1.types.GetSpecialistPoolRequest): + The request object. Request message for + ``SpecialistPoolService.GetSpecialistPool``. + name (str): + Required. The name of the SpecialistPool resource. The + form is + + ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}``. + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.types.SpecialistPool: + SpecialistPool represents customers' + own workforce to work on their data + labeling jobs. It includes a group of + specialist managers who are responsible + for managing the labelers in this pool + as well as customers' data labeling jobs + associated with this pool. + Customers create specialist pool as well + as start data labeling jobs on Cloud, + managers and labelers work with the jobs + using CrowdCompute console. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.GetSpecialistPoolRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.GetSpecialistPoolRequest): + request = specialist_pool_service.GetSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_specialist_pool] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. 
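+        # (The routing header built above is sent as gRPC metadata of the
+        # form ("x-goog-request-params", "name=<resource name>"); the value
+        # shown is illustrative.)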
+ response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Done; return the response. + return response + + def list_specialist_pools( + self, + request: specialist_pool_service.ListSpecialistPoolsRequest = None, + *, + parent: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> pagers.ListSpecialistPoolsPager: + r"""Lists SpecialistPools in a Location. + + Args: + request (google.cloud.aiplatform_v1.types.ListSpecialistPoolsRequest): + The request object. Request message for + ``SpecialistPoolService.ListSpecialistPools``. + parent (str): + Required. The name of the SpecialistPool's parent + resource. Format: + ``projects/{project}/locations/{location}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.cloud.aiplatform_v1.services.specialist_pool_service.pagers.ListSpecialistPoolsPager: + Response message for + ``SpecialistPoolService.ListSpecialistPools``. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.ListSpecialistPoolsRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.ListSpecialistPoolsRequest): + request = specialist_pool_service.ListSpecialistPoolsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_specialist_pools] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListSpecialistPoolsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_specialist_pool( + self, + request: specialist_pool_service.DeleteSpecialistPoolRequest = None, + *, + name: str = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Deletes a SpecialistPool as well as all Specialists + in the pool. 
+ + Args: + request (google.cloud.aiplatform_v1.types.DeleteSpecialistPoolRequest): + The request object. Request message for + ``SpecialistPoolService.DeleteSpecialistPool``. + name (str): + Required. The resource name of the SpecialistPool to + delete. Format: + ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: + + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); + + } + + The JSON representation for Empty is empty JSON + object {}. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.DeleteSpecialistPoolRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.DeleteSpecialistPoolRequest): + request = specialist_pool_service.DeleteSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_specialist_pool] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + empty.Empty, + metadata_type=gca_operation.DeleteOperationMetadata, + ) + + # Done; return the response. + return response + + def update_specialist_pool( + self, + request: specialist_pool_service.UpdateSpecialistPoolRequest = None, + *, + specialist_pool: gca_specialist_pool.SpecialistPool = None, + update_mask: field_mask.FieldMask = None, + retry: retries.Retry = gapic_v1.method.DEFAULT, + timeout: float = None, + metadata: Sequence[Tuple[str, str]] = (), + ) -> ga_operation.Operation: + r"""Updates a SpecialistPool. + + Args: + request (google.cloud.aiplatform_v1.types.UpdateSpecialistPoolRequest): + The request object. Request message for + ``SpecialistPoolService.UpdateSpecialistPool``. 
+ specialist_pool (google.cloud.aiplatform_v1.types.SpecialistPool): + Required. The SpecialistPool which + replaces the resource on the server. + + This corresponds to the ``specialist_pool`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to + the resource. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be :class:`google.cloud.aiplatform_v1.types.SpecialistPool` SpecialistPool represents customers' own workforce to work on their data + labeling jobs. It includes a group of specialist + managers who are responsible for managing the + labelers in this pool as well as customers' data + labeling jobs associated with this pool. Customers + create specialist pool as well as start data labeling + jobs on Cloud, managers and labelers work with the + jobs using CrowdCompute console. + + """ + # Create or coerce a protobuf request object. + # Sanity check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([specialist_pool, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # Minor optimization to avoid making a copy if the user passes + # in a specialist_pool_service.UpdateSpecialistPoolRequest. + # There's no risk of modifying the input as we've already verified + # there are no flattened fields. + if not isinstance(request, specialist_pool_service.UpdateSpecialistPoolRequest): + request = specialist_pool_service.UpdateSpecialistPoolRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + + if specialist_pool is not None: + request.specialist_pool = specialist_pool + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_specialist_pool] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("specialist_pool.name", request.specialist_pool.name),) + ), + ) + + # Send the request. + response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + + # Wrap the response in an operation future. + response = ga_operation.from_gapic( + response, + self._transport.operations_client, + gca_specialist_pool.SpecialistPool, + metadata_type=specialist_pool_service.UpdateSpecialistPoolOperationMetadata, + ) + + # Done; return the response. 
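+        # The caller may block on the returned future with
+        # ``response.result()``, which yields the updated SpecialistPool.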
+ return response + + +try: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=pkg_resources.get_distribution( + "google-cloud-aiplatform", + ).version, + ) +except pkg_resources.DistributionNotFound: + DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo() + + +__all__ = ("SpecialistPoolServiceClient",) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py new file mode 100644 index 0000000000..b55e53169e --- /dev/null +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/pagers.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, AsyncIterable, Awaitable, Callable, Iterable, Sequence, Tuple + +from google.cloud.aiplatform_v1.types import specialist_pool +from google.cloud.aiplatform_v1.types import specialist_pool_service + + +class ListSpecialistPoolsPager: + """A pager for iterating through ``list_specialist_pools`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListSpecialistPoolsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``specialist_pools`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListSpecialistPools`` requests and continue to iterate + through the ``specialist_pools`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListSpecialistPoolsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., specialist_pool_service.ListSpecialistPoolsResponse], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListSpecialistPoolsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListSpecialistPoolsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
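+
+        Example (illustrative; a pager is normally obtained from
+        ``SpecialistPoolServiceClient.list_specialist_pools`` rather than
+        constructed directly)::
+
+            pager = client.list_specialist_pools(parent=parent)
+            for pool in pager:  # Later pages are fetched on demand.
+                print(pool.name)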
+ """ + self._method = method + self._request = specialist_pool_service.ListSpecialistPoolsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[specialist_pool_service.ListSpecialistPoolsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[specialist_pool.SpecialistPool]: + for page in self.pages: + yield from page.specialist_pools + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListSpecialistPoolsAsyncPager: + """A pager for iterating through ``list_specialist_pools`` requests. + + This class thinly wraps an initial + :class:`google.cloud.aiplatform_v1.types.ListSpecialistPoolsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``specialist_pools`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListSpecialistPools`` requests and continue to iterate + through the ``specialist_pools`` field on the + corresponding responses. + + All the usual :class:`google.cloud.aiplatform_v1.types.ListSpecialistPoolsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., Awaitable[specialist_pool_service.ListSpecialistPoolsResponse] + ], + request: specialist_pool_service.ListSpecialistPoolsRequest, + response: specialist_pool_service.ListSpecialistPoolsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.aiplatform_v1.types.ListSpecialistPoolsRequest): + The initial request object. + response (google.cloud.aiplatform_v1.types.ListSpecialistPoolsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. 
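+
+        Example (illustrative; a pager is normally obtained from the async
+        client's ``list_specialist_pools`` method)::
+
+            pager = await client.list_specialist_pools(parent=parent)
+            async for pool in pager:  # Pages are awaited as needed.
+                print(pool.name)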
+ """ + self._method = method + self._request = specialist_pool_service.ListSpecialistPoolsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages( + self, + ) -> AsyncIterable[specialist_pool_service.ListSpecialistPoolsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method(self._request, metadata=self._metadata) + yield self._response + + def __aiter__(self) -> AsyncIterable[specialist_pool.SpecialistPool]: + async def async_generator(): + async for page in self.pages: + for response in page.specialist_pools: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/__init__.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/__init__.py new file mode 100644 index 0000000000..1bb2fbf22a --- /dev/null +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/__init__.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections import OrderedDict +from typing import Dict, Type + +from .base import SpecialistPoolServiceTransport +from .grpc import SpecialistPoolServiceGrpcTransport +from .grpc_asyncio import SpecialistPoolServiceGrpcAsyncIOTransport + + +# Compile a registry of transports. +_transport_registry = ( + OrderedDict() +) # type: Dict[str, Type[SpecialistPoolServiceTransport]] +_transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport +_transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport + +__all__ = ( + "SpecialistPoolServiceTransport", + "SpecialistPoolServiceGrpcTransport", + "SpecialistPoolServiceGrpcAsyncIOTransport", +) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py new file mode 100644 index 0000000000..56de21b988 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/base.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import abc
+import typing
+import pkg_resources
+
+from google import auth  # type: ignore
+from google.api_core import exceptions  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google.api_core import retry as retries  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google.auth import credentials  # type: ignore
+
+from google.cloud.aiplatform_v1.types import specialist_pool
+from google.cloud.aiplatform_v1.types import specialist_pool_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+
+try:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(
+        gapic_version=pkg_resources.get_distribution(
+            "google-cloud-aiplatform",
+        ).version,
+    )
+except pkg_resources.DistributionNotFound:
+    DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo()
+
+
+class SpecialistPoolServiceTransport(abc.ABC):
+    """Abstract transport class for SpecialistPoolService."""
+
+    AUTH_SCOPES = ("https://www.googleapis.com/auth/cloud-platform",)
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: typing.Optional[str] = None,
+        scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
+        quota_project_id: typing.Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+        **kwargs,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): A list of scopes.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+        """
+        # Save the hostname. Default to port 443 (HTTPS) if none is specified.
+        if ":" not in host:
+            host += ":443"
+        self._host = host
+
+        # If no credentials are provided, then determine the appropriate
+        # defaults.
+        if credentials and credentials_file:
+            raise exceptions.DuplicateCredentialArgs(
+                "'credentials_file' and 'credentials' are mutually exclusive"
+            )
+
+        if credentials_file is not None:
+            credentials, _ = auth.load_credentials_from_file(
+                credentials_file, scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        elif credentials is None:
+            credentials, _ = auth.default(
+                scopes=scopes, quota_project_id=quota_project_id
+            )
+
+        # Save the credentials.
+        self._credentials = credentials
+
+        # Lifted into its own function so it can be stubbed out during tests.
+        self._prep_wrapped_messages(client_info)
+
+    def _prep_wrapped_messages(self, client_info):
+        # Precompute the wrapped methods.
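+        # Each entry pairs a transport method with its default retry/timeout
+        # policy; client methods dispatch through this table rather than
+        # calling the raw stubs directly.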
+ self._wrapped_methods = { + self.create_specialist_pool: gapic_v1.method.wrap_method( + self.create_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + self.get_specialist_pool: gapic_v1.method.wrap_method( + self.get_specialist_pool, default_timeout=None, client_info=client_info, + ), + self.list_specialist_pools: gapic_v1.method.wrap_method( + self.list_specialist_pools, + default_timeout=None, + client_info=client_info, + ), + self.delete_specialist_pool: gapic_v1.method.wrap_method( + self.delete_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + self.update_specialist_pool: gapic_v1.method.wrap_method( + self.update_specialist_pool, + default_timeout=None, + client_info=client_info, + ), + } + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def create_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def get_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + typing.Union[ + specialist_pool.SpecialistPool, + typing.Awaitable[specialist_pool.SpecialistPool], + ], + ]: + raise NotImplementedError() + + @property + def list_specialist_pools( + self, + ) -> typing.Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + typing.Union[ + specialist_pool_service.ListSpecialistPoolsResponse, + typing.Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + @property + def update_specialist_pool( + self, + ) -> typing.Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + typing.Union[operations.Operation, typing.Awaitable[operations.Operation]], + ]: + raise NotImplementedError() + + +__all__ = ("SpecialistPoolServiceTransport",) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py new file mode 100644 index 0000000000..c9895648d2 --- /dev/null +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc.py @@ -0,0 +1,418 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import warnings
+from typing import Callable, Dict, Optional, Sequence, Tuple
+
+from google.api_core import grpc_helpers  # type: ignore
+from google.api_core import operations_v1  # type: ignore
+from google.api_core import gapic_v1  # type: ignore
+from google import auth  # type: ignore
+from google.auth import credentials  # type: ignore
+from google.auth.transport.grpc import SslCredentials  # type: ignore
+
+import grpc  # type: ignore
+
+from google.cloud.aiplatform_v1.types import specialist_pool
+from google.cloud.aiplatform_v1.types import specialist_pool_service
+from google.longrunning import operations_pb2 as operations  # type: ignore
+
+from .base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO
+
+
+class SpecialistPoolServiceGrpcTransport(SpecialistPoolServiceTransport):
+    """gRPC backend transport for SpecialistPoolService.
+
+    A service for creating and managing Customer SpecialistPools.
+    When customers start Data Labeling jobs, they can reuse/create
+    Specialist Pools to bring their own Specialists to label the
+    data. Customers can add/remove Managers for the Specialist Pool
+    on Cloud console, then Managers will get email notifications to
+    manage Specialists and tasks on CrowdCompute console.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _stubs: Dict[str, Callable]
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Sequence[str] = None,
+        channel: grpc.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id: Optional[str] = None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): A list of scopes. This argument is
+                ignored if ``channel`` is provided.
+            channel (Optional[grpc.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel.
It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._ssl_channel_credentials = ssl_channel_credentials + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if channel: + # Sanity check: Ensure that channel and credentials are not both + # provided. + credentials = False + + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + elif api_mtls_endpoint: + host = ( + api_mtls_endpoint + if ":" in api_mtls_endpoint + else api_mtls_endpoint + ":443" + ) + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + ssl_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + ssl_credentials = SslCredentials().ssl_credentials + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None + + # Run the base constructor. 
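+        # The base constructor validates the credential arguments and
+        # prepares the retry/timeout-wrapped method table via
+        # _prep_wrapped_messages().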
+        super().__init__(
+            host=host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes or self.AUTH_SCOPES,
+            quota_project_id=quota_project_id,
+            client_info=client_info,
+        )
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: str = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> grpc.Channel:
+        """Create and return a gRPC channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is mutually exclusive with credentials.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            grpc.Channel: A gRPC channel object.
+
+        Raises:
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    @property
+    def grpc_channel(self) -> grpc.Channel:
+        """Return the channel designed to connect to this service.
+        """
+        return self._grpc_channel
+
+    @property
+    def operations_client(self) -> operations_v1.OperationsClient:
+        """Create the client designed to process long-running operations.
+
+        This property caches on the instance; repeated calls return the same
+        client.
+        """
+        # Sanity check: Only create a new client if we do not already have one.
+        if self._operations_client is None:
+            self._operations_client = operations_v1.OperationsClient(self.grpc_channel)
+
+        # Return the client from cache.
+        return self._operations_client
+
+    @property
+    def create_specialist_pool(
+        self,
+    ) -> Callable[
+        [specialist_pool_service.CreateSpecialistPoolRequest], operations.Operation
+    ]:
+        r"""Return a callable for the create specialist pool method over gRPC.
+
+        Creates a SpecialistPool.
+
+        Returns:
+            Callable[[~.CreateSpecialistPoolRequest],
+                    ~.Operation]:
+                A function that, when called, will call the underlying RPC
+                on the server.
+        """
+        # Generate a "stub function" on-the-fly which will actually make
+        # the request.
+        # gRPC handles serialization and deserialization, so we just need
+        # to pass in the functions for each.
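+        # The stub is created once and cached in self._stubs, so repeated
+        # property accesses reuse the same callable.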
+ if "create_specialist_pool" not in self._stubs: + self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/CreateSpecialistPool", + request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["create_specialist_pool"] + + @property + def get_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + specialist_pool.SpecialistPool, + ]: + r"""Return a callable for the get specialist pool method over gRPC. + + Gets a SpecialistPool. + + Returns: + Callable[[~.GetSpecialistPoolRequest], + ~.SpecialistPool]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_specialist_pool" not in self._stubs: + self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/GetSpecialistPool", + request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, + response_deserializer=specialist_pool.SpecialistPool.deserialize, + ) + return self._stubs["get_specialist_pool"] + + @property + def list_specialist_pools( + self, + ) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + specialist_pool_service.ListSpecialistPoolsResponse, + ]: + r"""Return a callable for the list specialist pools method over gRPC. + + Lists SpecialistPools in a Location. + + Returns: + Callable[[~.ListSpecialistPoolsRequest], + ~.ListSpecialistPoolsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_specialist_pools" not in self._stubs: + self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/ListSpecialistPools", + request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, + response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, + ) + return self._stubs["list_specialist_pools"] + + @property + def delete_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], operations.Operation + ]: + r"""Return a callable for the delete specialist pool method over gRPC. + + Deletes a SpecialistPool as well as all Specialists + in the pool. + + Returns: + Callable[[~.DeleteSpecialistPoolRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
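+        # Note: LRO methods deserialize with Operation.FromString because
+        # google.longrunning messages are plain protobufs, unlike the
+        # proto-plus types, which expose a ``deserialize`` method.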
+ if "delete_specialist_pool" not in self._stubs: + self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/DeleteSpecialistPool", + request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_specialist_pool"] + + @property + def update_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], operations.Operation + ]: + r"""Return a callable for the update specialist pool method over gRPC. + + Updates a SpecialistPool. + + Returns: + Callable[[~.UpdateSpecialistPoolRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_specialist_pool" not in self._stubs: + self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/UpdateSpecialistPool", + request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["update_specialist_pool"] + + +__all__ = ("SpecialistPoolServiceGrpcTransport",) diff --git a/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py new file mode 100644 index 0000000000..566d0b022b --- /dev/null +++ b/google/cloud/aiplatform_v1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -0,0 +1,427 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple + +from google.api_core import gapic_v1 # type: ignore +from google.api_core import grpc_helpers_async # type: ignore +from google.api_core import operations_v1 # type: ignore +from google import auth # type: ignore +from google.auth import credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore + +import grpc # type: ignore +from grpc.experimental import aio # type: ignore + +from google.cloud.aiplatform_v1.types import specialist_pool +from google.cloud.aiplatform_v1.types import specialist_pool_service +from google.longrunning import operations_pb2 as operations # type: ignore + +from .base import SpecialistPoolServiceTransport, DEFAULT_CLIENT_INFO +from .grpc import SpecialistPoolServiceGrpcTransport + + +class SpecialistPoolServiceGrpcAsyncIOTransport(SpecialistPoolServiceTransport): + """gRPC AsyncIO backend transport for SpecialistPoolService. + + A service for creating and managing Customer SpecialistPools. 
+    When customers start Data Labeling jobs, they can reuse/create
+    Specialist Pools to bring their own Specialists to label the
+    data. Customers can add/remove Managers for the Specialist Pool
+    on Cloud console, then Managers will get email notifications to
+    manage Specialists and tasks on CrowdCompute console.
+
+    This class defines the same methods as the primary client, so the
+    primary client can load the underlying transport implementation
+    and call it.
+
+    It sends protocol buffers over the wire using gRPC (which is built on
+    top of HTTP/2); the ``grpcio`` package must be installed.
+    """
+
+    _grpc_channel: aio.Channel
+    _stubs: Dict[str, Callable] = {}
+
+    @classmethod
+    def create_channel(
+        cls,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        quota_project_id: Optional[str] = None,
+        **kwargs,
+    ) -> aio.Channel:
+        """Create and return a gRPC AsyncIO channel object.
+        Args:
+            host (Optional[str]): The host for the channel to use.
+            credentials (Optional[~.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify this application to the service. If
+                none are specified, the client will attempt to ascertain
+                the credentials from the environment.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            kwargs (Optional[dict]): Keyword arguments, which are passed to the
+                channel creation.
+        Returns:
+            aio.Channel: A gRPC AsyncIO channel object.
+        """
+        scopes = scopes or cls.AUTH_SCOPES
+        return grpc_helpers_async.create_channel(
+            host,
+            credentials=credentials,
+            credentials_file=credentials_file,
+            scopes=scopes,
+            quota_project_id=quota_project_id,
+            **kwargs,
+        )
+
+    def __init__(
+        self,
+        *,
+        host: str = "aiplatform.googleapis.com",
+        credentials: credentials.Credentials = None,
+        credentials_file: Optional[str] = None,
+        scopes: Optional[Sequence[str]] = None,
+        channel: aio.Channel = None,
+        api_mtls_endpoint: str = None,
+        client_cert_source: Callable[[], Tuple[bytes, bytes]] = None,
+        ssl_channel_credentials: grpc.ChannelCredentials = None,
+        client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None,
+        quota_project_id=None,
+        client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO,
+    ) -> None:
+        """Instantiate the transport.
+
+        Args:
+            host (Optional[str]): The hostname to connect to.
+            credentials (Optional[google.auth.credentials.Credentials]): The
+                authorization credentials to attach to requests. These
+                credentials identify the application to the service; if none
+                are specified, the client will attempt to ascertain the
+                credentials from the environment.
+                This argument is ignored if ``channel`` is provided.
+            credentials_file (Optional[str]): A file with credentials that can
+                be loaded with :func:`google.auth.load_credentials_from_file`.
+                This argument is ignored if ``channel`` is provided.
+            scopes (Optional[Sequence[str]]): An optional list of scopes needed for this
+                service. These are only used when credentials are not specified and
+                are passed to :func:`google.auth.default`.
+            channel (Optional[aio.Channel]): A ``Channel`` instance through
+                which to make calls.
+            api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint.
+                If provided, it overrides the ``host`` argument and tries to create
+                a mutual TLS channel with client SSL credentials from
+                ``client_cert_source`` or application default SSL credentials.
+            client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                Deprecated. A callback to provide client SSL certificate bytes and
+                private key bytes, both in PEM format. It is ignored if
+                ``api_mtls_endpoint`` is None.
+            ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials
+                for grpc channel. It is ignored if ``channel`` is provided.
+            client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]):
+                A callback to provide client certificate bytes and private key bytes,
+                both in PEM format. It is used to configure mutual TLS channel. It is
+                ignored if ``channel`` or ``ssl_channel_credentials`` is provided.
+            quota_project_id (Optional[str]): An optional project to use for billing
+                and quota.
+            client_info (google.api_core.gapic_v1.client_info.ClientInfo):
+                The client info used to send a user-agent string along with
+                API requests. If ``None``, then default info will be used.
+                Generally, you only need to set this if you're developing
+                your own client library.
+
+        Raises:
+            google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport
+                creation failed for any reason.
+            google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
+                and ``credentials_file`` are passed.
+        """
+        self._ssl_channel_credentials = ssl_channel_credentials
+
+        if api_mtls_endpoint:
+            warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning)
+        if client_cert_source:
+            warnings.warn("client_cert_source is deprecated", DeprecationWarning)
+
+        if channel:
+            # Sanity check: Ensure that channel and credentials are not both
+            # provided.
+            credentials = False
+
+            # If a channel was explicitly provided, set it.
+            self._grpc_channel = channel
+            self._ssl_channel_credentials = None
+        elif api_mtls_endpoint:
+            host = (
+                api_mtls_endpoint
+                if ":" in api_mtls_endpoint
+                else api_mtls_endpoint + ":443"
+            )
+
+            if credentials is None:
+                credentials, _ = auth.default(
+                    scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id
+                )
+
+            # Create SSL credentials with client_cert_source or application
+            # default SSL credentials.
+            if client_cert_source:
+                cert, key = client_cert_source()
+                ssl_credentials = grpc.ssl_channel_credentials(
+                    certificate_chain=cert, private_key=key
+                )
+            else:
+                ssl_credentials = SslCredentials().ssl_credentials
+
+            # create a new channel. The provided one is ignored.
+ self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=ssl_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + self._ssl_channel_credentials = ssl_credentials + else: + host = host if ":" in host else host + ":443" + + if credentials is None: + credentials, _ = auth.default( + scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id + ) + + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # create a new channel. The provided one is ignored. + self._grpc_channel = type(self).create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + ssl_credentials=self._ssl_channel_credentials, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Run the base constructor. + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes or self.AUTH_SCOPES, + quota_project_id=quota_project_id, + client_info=client_info, + ) + + self._stubs = {} + self._operations_client = None + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Sanity check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self.grpc_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def create_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.CreateSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the create specialist pool method over gRPC. + + Creates a SpecialistPool. + + Returns: + Callable[[~.CreateSpecialistPoolRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_specialist_pool" not in self._stubs: + self._stubs["create_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/CreateSpecialistPool", + request_serializer=specialist_pool_service.CreateSpecialistPoolRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["create_specialist_pool"] + + @property + def get_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.GetSpecialistPoolRequest], + Awaitable[specialist_pool.SpecialistPool], + ]: + r"""Return a callable for the get specialist pool method over gRPC. + + Gets a SpecialistPool. 
+ + Returns: + Callable[[~.GetSpecialistPoolRequest], + Awaitable[~.SpecialistPool]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_specialist_pool" not in self._stubs: + self._stubs["get_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/GetSpecialistPool", + request_serializer=specialist_pool_service.GetSpecialistPoolRequest.serialize, + response_deserializer=specialist_pool.SpecialistPool.deserialize, + ) + return self._stubs["get_specialist_pool"] + + @property + def list_specialist_pools( + self, + ) -> Callable[ + [specialist_pool_service.ListSpecialistPoolsRequest], + Awaitable[specialist_pool_service.ListSpecialistPoolsResponse], + ]: + r"""Return a callable for the list specialist pools method over gRPC. + + Lists SpecialistPools in a Location. + + Returns: + Callable[[~.ListSpecialistPoolsRequest], + Awaitable[~.ListSpecialistPoolsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_specialist_pools" not in self._stubs: + self._stubs["list_specialist_pools"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/ListSpecialistPools", + request_serializer=specialist_pool_service.ListSpecialistPoolsRequest.serialize, + response_deserializer=specialist_pool_service.ListSpecialistPoolsResponse.deserialize, + ) + return self._stubs["list_specialist_pools"] + + @property + def delete_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.DeleteSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the delete specialist pool method over gRPC. + + Deletes a SpecialistPool as well as all Specialists + in the pool. + + Returns: + Callable[[~.DeleteSpecialistPoolRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_specialist_pool" not in self._stubs: + self._stubs["delete_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/DeleteSpecialistPool", + request_serializer=specialist_pool_service.DeleteSpecialistPoolRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["delete_specialist_pool"] + + @property + def update_specialist_pool( + self, + ) -> Callable[ + [specialist_pool_service.UpdateSpecialistPoolRequest], + Awaitable[operations.Operation], + ]: + r"""Return a callable for the update specialist pool method over gRPC. + + Updates a SpecialistPool. + + Returns: + Callable[[~.UpdateSpecialistPoolRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. 
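+        # On the AsyncIO channel the stub returns an awaitable, so callers
+        # must ``await`` the resulting Operation message.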
+ if "update_specialist_pool" not in self._stubs: + self._stubs["update_specialist_pool"] = self.grpc_channel.unary_unary( + "/google.cloud.aiplatform.v1.SpecialistPoolService/UpdateSpecialistPool", + request_serializer=specialist_pool_service.UpdateSpecialistPoolRequest.serialize, + response_deserializer=operations.Operation.FromString, + ) + return self._stubs["update_specialist_pool"] + + +__all__ = ("SpecialistPoolServiceGrpcAsyncIOTransport",) diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py new file mode 100644 index 0000000000..f073d451fe --- /dev/null +++ b/google/cloud/aiplatform_v1/types/__init__.py @@ -0,0 +1,361 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .user_action_reference import UserActionReference +from .annotation import Annotation +from .annotation_spec import AnnotationSpec +from .completion_stats import CompletionStats +from .encryption_spec import EncryptionSpec +from .io import ( + GcsSource, + GcsDestination, + BigQuerySource, + BigQueryDestination, + ContainerRegistryDestination, +) +from .machine_resources import ( + MachineSpec, + DedicatedResources, + AutomaticResources, + BatchDedicatedResources, + ResourcesConsumed, + DiskSpec, +) +from .manual_batch_tuning_parameters import ManualBatchTuningParameters +from .batch_prediction_job import BatchPredictionJob +from .env_var import EnvVar +from .custom_job import ( + CustomJob, + CustomJobSpec, + WorkerPoolSpec, + ContainerSpec, + PythonPackageSpec, + Scheduling, +) +from .data_item import DataItem +from .specialist_pool import SpecialistPool +from .data_labeling_job import ( + DataLabelingJob, + ActiveLearningConfig, + SampleConfig, + TrainingConfig, +) +from .dataset import ( + Dataset, + ImportDataConfig, + ExportDataConfig, +) +from .operation import ( + GenericOperationMetadata, + DeleteOperationMetadata, +) +from .deployed_model_ref import DeployedModelRef +from .model import ( + Model, + PredictSchemata, + ModelContainerSpec, + Port, +) +from .training_pipeline import ( + TrainingPipeline, + InputDataConfig, + FractionSplit, + FilterSplit, + PredefinedSplit, + TimestampSplit, +) +from .dataset_service import ( + CreateDatasetRequest, + CreateDatasetOperationMetadata, + GetDatasetRequest, + UpdateDatasetRequest, + ListDatasetsRequest, + ListDatasetsResponse, + DeleteDatasetRequest, + ImportDataRequest, + ImportDataResponse, + ImportDataOperationMetadata, + ExportDataRequest, + ExportDataResponse, + ExportDataOperationMetadata, + ListDataItemsRequest, + ListDataItemsResponse, + GetAnnotationSpecRequest, + ListAnnotationsRequest, + ListAnnotationsResponse, +) +from .endpoint import ( + Endpoint, + DeployedModel, +) +from .endpoint_service import ( + CreateEndpointRequest, + CreateEndpointOperationMetadata, + GetEndpointRequest, + ListEndpointsRequest, + ListEndpointsResponse, + UpdateEndpointRequest, + DeleteEndpointRequest, + DeployModelRequest, + DeployModelResponse, + 
DeployModelOperationMetadata,
+    UndeployModelRequest,
+    UndeployModelResponse,
+    UndeployModelOperationMetadata,
+)
+from .study import (
+    Trial,
+    StudySpec,
+    Measurement,
+)
+from .hyperparameter_tuning_job import HyperparameterTuningJob
+from .job_service import (
+    CreateCustomJobRequest,
+    GetCustomJobRequest,
+    ListCustomJobsRequest,
+    ListCustomJobsResponse,
+    DeleteCustomJobRequest,
+    CancelCustomJobRequest,
+    CreateDataLabelingJobRequest,
+    GetDataLabelingJobRequest,
+    ListDataLabelingJobsRequest,
+    ListDataLabelingJobsResponse,
+    DeleteDataLabelingJobRequest,
+    CancelDataLabelingJobRequest,
+    CreateHyperparameterTuningJobRequest,
+    GetHyperparameterTuningJobRequest,
+    ListHyperparameterTuningJobsRequest,
+    ListHyperparameterTuningJobsResponse,
+    DeleteHyperparameterTuningJobRequest,
+    CancelHyperparameterTuningJobRequest,
+    CreateBatchPredictionJobRequest,
+    GetBatchPredictionJobRequest,
+    ListBatchPredictionJobsRequest,
+    ListBatchPredictionJobsResponse,
+    DeleteBatchPredictionJobRequest,
+    CancelBatchPredictionJobRequest,
+)
+from .migratable_resource import MigratableResource
+from .migration_service import (
+    SearchMigratableResourcesRequest,
+    SearchMigratableResourcesResponse,
+    BatchMigrateResourcesRequest,
+    MigrateResourceRequest,
+    BatchMigrateResourcesResponse,
+    MigrateResourceResponse,
+    BatchMigrateResourcesOperationMetadata,
+)
+from .model_evaluation import ModelEvaluation
+from .model_evaluation_slice import ModelEvaluationSlice
+from .model_service import (
+    UploadModelRequest,
+    UploadModelOperationMetadata,
+    UploadModelResponse,
+    GetModelRequest,
+    ListModelsRequest,
+    ListModelsResponse,
+    UpdateModelRequest,
+    DeleteModelRequest,
+    ExportModelRequest,
+    ExportModelOperationMetadata,
+    ExportModelResponse,
+    GetModelEvaluationRequest,
+    ListModelEvaluationsRequest,
+    ListModelEvaluationsResponse,
+    GetModelEvaluationSliceRequest,
+    ListModelEvaluationSlicesRequest,
+    ListModelEvaluationSlicesResponse,
+)
+from .pipeline_service import (
+    CreateTrainingPipelineRequest,
+    GetTrainingPipelineRequest,
+    ListTrainingPipelinesRequest,
+    ListTrainingPipelinesResponse,
+    DeleteTrainingPipelineRequest,
+    CancelTrainingPipelineRequest,
+)
+from .prediction_service import (
+    PredictRequest,
+    PredictResponse,
+)
+from .specialist_pool_service import (
+    CreateSpecialistPoolRequest,
+    CreateSpecialistPoolOperationMetadata,
+    GetSpecialistPoolRequest,
+    ListSpecialistPoolsRequest,
+    ListSpecialistPoolsResponse,
+    DeleteSpecialistPoolRequest,
+    UpdateSpecialistPoolRequest,
+    UpdateSpecialistPoolOperationMetadata,
+)
+
+# Enum types re-exported in ``__all__`` below.
+from .accelerator_type import AcceleratorType
+from .job_state import JobState
+from .pipeline_state import PipelineState
+
+__all__ = (
+    "AcceleratorType",
+    "UserActionReference",
+    "Annotation",
+    "AnnotationSpec",
+    "CompletionStats",
+    "EncryptionSpec",
+    "GcsSource",
+    "GcsDestination",
+    "BigQuerySource",
+    "BigQueryDestination",
+    "ContainerRegistryDestination",
+    "JobState",
+    "MachineSpec",
+    "DedicatedResources",
+    "AutomaticResources",
+    "BatchDedicatedResources",
+    "ResourcesConsumed",
+    "DiskSpec",
+    "ManualBatchTuningParameters",
+    "BatchPredictionJob",
+    "EnvVar",
+    "CustomJob",
+    "CustomJobSpec",
+    "WorkerPoolSpec",
+    "ContainerSpec",
+    "PythonPackageSpec",
+    "Scheduling",
+    "DataItem",
+    "SpecialistPool",
+    "DataLabelingJob",
+    "ActiveLearningConfig",
+    "SampleConfig",
+    "TrainingConfig",
+    "Dataset",
+    "ImportDataConfig",
+    "ExportDataConfig",
+    "GenericOperationMetadata",
+    "DeleteOperationMetadata",
+    "DeployedModelRef",
+    "Model",
+    "PredictSchemata",
+    "ModelContainerSpec",
+    "Port",
+    "PipelineState",
+    "TrainingPipeline",
+    "InputDataConfig",
+
"FractionSplit", + "FilterSplit", + "PredefinedSplit", + "TimestampSplit", + "CreateDatasetRequest", + "CreateDatasetOperationMetadata", + "GetDatasetRequest", + "UpdateDatasetRequest", + "ListDatasetsRequest", + "ListDatasetsResponse", + "DeleteDatasetRequest", + "ImportDataRequest", + "ImportDataResponse", + "ImportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "ExportDataOperationMetadata", + "ListDataItemsRequest", + "ListDataItemsResponse", + "GetAnnotationSpecRequest", + "ListAnnotationsRequest", + "ListAnnotationsResponse", + "Endpoint", + "DeployedModel", + "CreateEndpointRequest", + "CreateEndpointOperationMetadata", + "GetEndpointRequest", + "ListEndpointsRequest", + "ListEndpointsResponse", + "UpdateEndpointRequest", + "DeleteEndpointRequest", + "DeployModelRequest", + "DeployModelResponse", + "DeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UndeployModelOperationMetadata", + "Trial", + "StudySpec", + "Measurement", + "HyperparameterTuningJob", + "CreateCustomJobRequest", + "GetCustomJobRequest", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "DeleteCustomJobRequest", + "CancelCustomJobRequest", + "CreateDataLabelingJobRequest", + "GetDataLabelingJobRequest", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "DeleteDataLabelingJobRequest", + "CancelDataLabelingJobRequest", + "CreateHyperparameterTuningJobRequest", + "GetHyperparameterTuningJobRequest", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "DeleteHyperparameterTuningJobRequest", + "CancelHyperparameterTuningJobRequest", + "CreateBatchPredictionJobRequest", + "GetBatchPredictionJobRequest", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "DeleteBatchPredictionJobRequest", + "CancelBatchPredictionJobRequest", + "MigratableResource", + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "BatchMigrateResourcesRequest", + "MigrateResourceRequest", + "BatchMigrateResourcesResponse", + "MigrateResourceResponse", + "BatchMigrateResourcesOperationMetadata", + "ModelEvaluation", + "ModelEvaluationSlice", + "UploadModelRequest", + "UploadModelOperationMetadata", + "UploadModelResponse", + "GetModelRequest", + "ListModelsRequest", + "ListModelsResponse", + "UpdateModelRequest", + "DeleteModelRequest", + "ExportModelRequest", + "ExportModelOperationMetadata", + "ExportModelResponse", + "GetModelEvaluationRequest", + "ListModelEvaluationsRequest", + "ListModelEvaluationsResponse", + "GetModelEvaluationSliceRequest", + "ListModelEvaluationSlicesRequest", + "ListModelEvaluationSlicesResponse", + "CreateTrainingPipelineRequest", + "GetTrainingPipelineRequest", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "DeleteTrainingPipelineRequest", + "CancelTrainingPipelineRequest", + "PredictRequest", + "PredictResponse", + "CreateSpecialistPoolRequest", + "CreateSpecialistPoolOperationMetadata", + "GetSpecialistPoolRequest", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "DeleteSpecialistPoolRequest", + "UpdateSpecialistPoolRequest", + "UpdateSpecialistPoolOperationMetadata", +) diff --git a/google/cloud/aiplatform_v1/types/accelerator_type.py b/google/cloud/aiplatform_v1/types/accelerator_type.py new file mode 100644 index 0000000000..640436c38c --- /dev/null +++ b/google/cloud/aiplatform_v1/types/accelerator_type.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache 
License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"AcceleratorType",}, +) + + +class AcceleratorType(proto.Enum): + r"""Represents a hardware accelerator type.""" + ACCELERATOR_TYPE_UNSPECIFIED = 0 + NVIDIA_TESLA_K80 = 1 + NVIDIA_TESLA_P100 = 2 + NVIDIA_TESLA_V100 = 3 + NVIDIA_TESLA_P4 = 4 + NVIDIA_TESLA_T4 = 5 + TPU_V2 = 6 + TPU_V3 = 7 + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/annotation.py b/google/cloud/aiplatform_v1/types/annotation.py new file mode 100644 index 0000000000..000ca49dcb --- /dev/null +++ b/google/cloud/aiplatform_v1/types/annotation.py @@ -0,0 +1,109 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import user_action_reference +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"Annotation",}, +) + + +class Annotation(proto.Message): + r"""Used to assign a specific AnnotationSpec to a particular area + of a DataItem or the whole part of the DataItem. + + Attributes: + name (str): + Output only. Resource name of the Annotation. + payload_schema_uri (str): + Required. Google Cloud Storage URI points to a YAML file + describing + ``payload``. + The schema is defined as an `OpenAPI 3.0.2 Schema + Object `__. The schema files + that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/, + note that the chosen schema must be consistent with the + parent Dataset's + ``metadata``. + payload (google.protobuf.struct_pb2.Value): + Required. The schema of the payload can be found in + ``payload_schema``. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Annotation + was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Annotation + was last updated. + etag (str): + Optional. Used to perform consistent read-modify-write updates. If not set, a blind + "overwrite" update happens. + annotation_source (google.cloud.aiplatform_v1.types.UserActionReference): + Output only. The source of the Annotation. + labels (Sequence[google.cloud.aiplatform_v1.types.Annotation.LabelsEntry]): + Optional.
The labels with user-defined metadata to organize + your Annotations. + + Label keys and values can be no longer than 64 characters + (Unicode codepoints), can only contain lowercase letters, + numeric characters, underscores and dashes. International + characters are allowed. No more than 64 user labels can be + associated with one Annotation (System labels are excluded). + + See https://goo.gl/xmQnxf for more information and examples + of labels. System reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. Following + system labels exist for each Annotation: + + - "aiplatform.googleapis.com/annotation_set_name": + optional, name of the UI's annotation set this Annotation + belongs to. If not set, the Annotation is not visible in + the UI. + + - "aiplatform.googleapis.com/payload_schema": output only, + its value is the + [payload_schema's][google.cloud.aiplatform.v1.Annotation.payload_schema_uri] + title. + """ + + name = proto.Field(proto.STRING, number=1) + + payload_schema_uri = proto.Field(proto.STRING, number=2) + + payload = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + + etag = proto.Field(proto.STRING, number=8) + + annotation_source = proto.Field( + proto.MESSAGE, number=5, message=user_action_reference.UserActionReference, + ) + + labels = proto.MapField(proto.STRING, proto.STRING, number=6) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/annotation_spec.py b/google/cloud/aiplatform_v1/types/annotation_spec.py new file mode 100644 index 0000000000..41f228ad72 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/annotation_spec.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"AnnotationSpec",}, +) + + +class AnnotationSpec(proto.Message): + r"""Identifies a concept with which DataItems may be annotated. + + Attributes: + name (str): + Output only. Resource name of the + AnnotationSpec. + display_name (str): + Required. The user-defined name of the + AnnotationSpec. The name can be up to 128 + characters long and can consist of any UTF-8 + characters. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + AnnotationSpec was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this AnnotationSpec + was last updated. + etag (str): + Optional. Used to perform consistent read-modify-write updates. If not set, a blind + "overwrite" update happens.
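The etag fields above all describe the same optimistic-concurrency contract. A minimal sketch of the intended read-modify-write flow, shown against Dataset (whose etag behaves the same way) and the generated DatasetServiceClient from this patch; the resource name is a placeholder:

from google.cloud import aiplatform_v1

client = aiplatform_v1.DatasetServiceClient()
name = "projects/123/locations/us-central1/datasets/456"  # placeholder

# Read: the returned Dataset carries the server's current etag.
dataset = client.get_dataset(name=name)
dataset.display_name = "renamed-dataset"

# Write: sending the read etag back unchanged lets the server reject the
# update instead of blindly overwriting if the resource changed meanwhile.
client.update_dataset(dataset=dataset, update_mask={"paths": ["display_name"]})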
+ """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + create_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + + etag = proto.Field(proto.STRING, number=5) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1/types/batch_prediction_job.py new file mode 100644 index 0000000000..d2d8f02203 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/batch_prediction_job.py @@ -0,0 +1,344 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import completion_stats as gca_completion_stats +from google.cloud.aiplatform_v1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform_v1.types import io +from google.cloud.aiplatform_v1.types import job_state +from google.cloud.aiplatform_v1.types import machine_resources +from google.cloud.aiplatform_v1.types import ( + manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters, +) +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"BatchPredictionJob",}, +) + + +class BatchPredictionJob(proto.Message): + r"""A job that uses a + ``Model`` to + produce predictions on multiple [input + instances][google.cloud.aiplatform.v1.BatchPredictionJob.input_config]. + If predictions for significant portion of the instances fail, the + job may finish without attempting predictions for all remaining + instances. + + Attributes: + name (str): + Output only. Resource name of the + BatchPredictionJob. + display_name (str): + Required. The user-defined name of this + BatchPredictionJob. + model (str): + Required. The name of the Model that produces + the predictions via this job, must share the + same ancestor Location. Starting this job has no + impact on any existing deployments of the Model + and their resources. + input_config (google.cloud.aiplatform_v1.types.BatchPredictionJob.InputConfig): + Required. Input configuration of the instances on which + predictions are performed. The schema of any single instance + may be specified via the + [Model's][google.cloud.aiplatform.v1.BatchPredictionJob.model] + [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] + ``instance_schema_uri``. + model_parameters (google.protobuf.struct_pb2.Value): + The parameters that govern the predictions. The schema of + the parameters may be specified via the + [Model's][google.cloud.aiplatform.v1.BatchPredictionJob.model] + [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] + ``parameters_schema_uri``. 
+ output_config (google.cloud.aiplatform_v1.types.BatchPredictionJob.OutputConfig): + Required. The configuration specifying where output + predictions should be written. The schema of any single + prediction may be specified as a concatenation of + [Model's][google.cloud.aiplatform.v1.BatchPredictionJob.model] + [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] + ``instance_schema_uri`` + and + ``prediction_schema_uri``. + dedicated_resources (google.cloud.aiplatform_v1.types.BatchDedicatedResources): + The config of resources used by the Model during the batch + prediction. If the Model + ``supports`` + DEDICATED_RESOURCES this config may be provided (and the job + will use these resources); if the Model doesn't support + AUTOMATIC_RESOURCES, this config must be provided. + manual_batch_tuning_parameters (google.cloud.aiplatform_v1.types.ManualBatchTuningParameters): + Immutable. Parameters configuring the batch behavior. + Currently only applicable when + ``dedicated_resources`` + are used (in other cases AI Platform does the tuning + itself). + output_info (google.cloud.aiplatform_v1.types.BatchPredictionJob.OutputInfo): + Output only. Information further describing + the output of this job. + state (google.cloud.aiplatform_v1.types.JobState): + Output only. The detailed state of the job. + error (google.rpc.status_pb2.Status): + Output only. Only populated when the job's state is + JOB_STATE_FAILED or JOB_STATE_CANCELLED. + partial_failures (Sequence[google.rpc.status_pb2.Status]): + Output only. Partial failures encountered. + For example, single files that can't be read. + This field never exceeds 20 entries. + Status details fields contain standard GCP error + details. + resources_consumed (google.cloud.aiplatform_v1.types.ResourcesConsumed): + Output only. Information about resources that + had been consumed by this job. Provided in real + time on a best-effort basis, as well as a final + value once the job completes. + + Note: This field currently may not be populated + for batch predictions that use AutoML Models. + completion_stats (google.cloud.aiplatform_v1.types.CompletionStats): + Output only. Statistics on completed and + failed prediction instances. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the BatchPredictionJob + was created. + start_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the BatchPredictionJob for the first + time entered the ``JOB_STATE_RUNNING`` state. + end_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the BatchPredictionJob entered any of + the following states: ``JOB_STATE_SUCCEEDED``, + ``JOB_STATE_FAILED``, ``JOB_STATE_CANCELLED``. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the BatchPredictionJob + was most recently updated. + labels (Sequence[google.cloud.aiplatform_v1.types.BatchPredictionJob.LabelsEntry]): + The labels with user-defined metadata to + organize BatchPredictionJobs. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): + Customer-managed encryption key options for a + BatchPredictionJob.
If this is set, then all + resources created by the BatchPredictionJob will + be encrypted with the provided encryption key. + """ + + class InputConfig(proto.Message): + r"""Configures the input to + ``BatchPredictionJob``. + See + ``Model.supported_input_storage_formats`` + for Model's supported input formats, and how instances should be + expressed via any of them. + + Attributes: + gcs_source (google.cloud.aiplatform_v1.types.GcsSource): + The Cloud Storage location for the input + instances. + bigquery_source (google.cloud.aiplatform_v1.types.BigQuerySource): + The BigQuery location of the input table. + The schema of the table should be in the format + described by the given context OpenAPI Schema, + if one is provided. The table may contain + additional columns that are not described by the + schema, and they will be ignored. + instances_format (str): + Required. The format in which instances are given, must be + one of the + [Model's][google.cloud.aiplatform.v1.BatchPredictionJob.model] + ``supported_input_storage_formats``. + """ + + gcs_source = proto.Field( + proto.MESSAGE, number=2, oneof="source", message=io.GcsSource, + ) + + bigquery_source = proto.Field( + proto.MESSAGE, number=3, oneof="source", message=io.BigQuerySource, + ) + + instances_format = proto.Field(proto.STRING, number=1) + + class OutputConfig(proto.Message): + r"""Configures the output of + ``BatchPredictionJob``. + See + ``Model.supported_output_storage_formats`` + for supported output formats, and how predictions are expressed via + any of them. + + Attributes: + gcs_destination (google.cloud.aiplatform_v1.types.GcsDestination): + The Cloud Storage location of the directory where the output + is to be written to. In the given directory a new directory + is created. Its name is + ``prediction--``, where + timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format. + Inside of it files ``predictions_0001.``, + ``predictions_0002.``, ..., + ``predictions_N.`` are created where + ```` depends on chosen + ``predictions_format``, + and N may equal 0001 and depends on the total number of + successfully predicted instances. If the Model has both + ``instance`` + and + ``prediction`` + schemata defined then each such file contains predictions as + per the + ``predictions_format``. + If prediction for any instance failed (partially or + completely), then an additional ``errors_0001.``, + ``errors_0002.``,..., ``errors_N.`` + files are created (N depends on total number of failed + predictions). These files contain the failed instances, as + per their schema, followed by an additional ``error`` field + which as value has ```google.rpc.Status`` `__ + containing only ``code`` and ``message`` fields. + bigquery_destination (google.cloud.aiplatform_v1.types.BigQueryDestination): + The BigQuery project location where the output is to be + written to. In the given project a new dataset is created + with name + ``prediction__`` where + is made BigQuery-dataset-name compatible (for example, most + special characters become underscores), and timestamp is in + YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601" format. In the + dataset two tables will be created, ``predictions``, and + ``errors``. If the Model has both + ``instance`` + and + ``prediction`` + schemata defined then the tables have columns as follows: + The ``predictions`` table contains instances for which the + prediction succeeded, it has columns as per a concatenation + of the Model's instance and prediction schemata. 
The + ``errors`` table contains rows for which the prediction has + failed, it has instance columns, as per the instance schema, + followed by a single "errors" column, which as values has + ```google.rpc.Status`` `__ represented as a STRUCT, + and containing only ``code`` and ``message``. + predictions_format (str): + Required. The format in which AI Platform gives the + predictions, must be one of the + [Model's][google.cloud.aiplatform.v1.BatchPredictionJob.model] + + ``supported_output_storage_formats``. + """ + + gcs_destination = proto.Field( + proto.MESSAGE, number=2, oneof="destination", message=io.GcsDestination, + ) + + bigquery_destination = proto.Field( + proto.MESSAGE, + number=3, + oneof="destination", + message=io.BigQueryDestination, + ) + + predictions_format = proto.Field(proto.STRING, number=1) + + class OutputInfo(proto.Message): + r"""Further describes this job's output. Supplements + ``output_config``. + + Attributes: + gcs_output_directory (str): + Output only. The full path of the Cloud + Storage directory created, into which the + prediction output is written. + bigquery_output_dataset (str): + Output only. The path of the BigQuery dataset created, in + ``bq://projectId.bqDatasetId`` format, into which the + prediction output is written. + """ + + gcs_output_directory = proto.Field( + proto.STRING, number=1, oneof="output_location" + ) + + bigquery_output_dataset = proto.Field( + proto.STRING, number=2, oneof="output_location" + ) + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + model = proto.Field(proto.STRING, number=3) + + input_config = proto.Field(proto.MESSAGE, number=4, message=InputConfig,) + + model_parameters = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) + + output_config = proto.Field(proto.MESSAGE, number=6, message=OutputConfig,) + + dedicated_resources = proto.Field( + proto.MESSAGE, number=7, message=machine_resources.BatchDedicatedResources, + ) + + manual_batch_tuning_parameters = proto.Field( + proto.MESSAGE, + number=8, + message=gca_manual_batch_tuning_parameters.ManualBatchTuningParameters, + ) + + output_info = proto.Field(proto.MESSAGE, number=9, message=OutputInfo,) + + state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) + + error = proto.Field(proto.MESSAGE, number=11, message=status.Status,) + + partial_failures = proto.RepeatedField( + proto.MESSAGE, number=12, message=status.Status, + ) + + resources_consumed = proto.Field( + proto.MESSAGE, number=13, message=machine_resources.ResourcesConsumed, + ) + + completion_stats = proto.Field( + proto.MESSAGE, number=14, message=gca_completion_stats.CompletionStats, + ) + + create_time = proto.Field(proto.MESSAGE, number=15, message=timestamp.Timestamp,) + + start_time = proto.Field(proto.MESSAGE, number=16, message=timestamp.Timestamp,) + + end_time = proto.Field(proto.MESSAGE, number=17, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=18, message=timestamp.Timestamp,) + + labels = proto.MapField(proto.STRING, proto.STRING, number=19) + + encryption_spec = proto.Field( + proto.MESSAGE, number=24, message=gca_encryption_spec.EncryptionSpec, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/completion_stats.py b/google/cloud/aiplatform_v1/types/completion_stats.py new file mode 100644 index 0000000000..05648d82c4 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/completion_stats.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 
-*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"CompletionStats",}, +) + + +class CompletionStats(proto.Message): + r"""Success and error statistics of processing multiple entities + (for example, DataItems or structured data rows) in batch. + + Attributes: + successful_count (int): + Output only. The number of entities that had + been processed successfully. + failed_count (int): + Output only. The number of entities for which + any error was encountered. + incomplete_count (int): + Output only. In cases when enough errors are + encountered, a job, pipeline, or operation may + fail as a whole. Below is the number of + entities for which the processing had not been + finished (either in a successful or a failed state). + Set to -1 if the number is unknown (for example, + the operation failed before the total entity + number could be collected). + """ + + successful_count = proto.Field(proto.INT64, number=1) + + failed_count = proto.Field(proto.INT64, number=2) + + incomplete_count = proto.Field(proto.INT64, number=3) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/custom_job.py b/google/cloud/aiplatform_v1/types/custom_job.py new file mode 100644 index 0000000000..c97cba6d82 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/custom_job.py @@ -0,0 +1,318 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform_v1.types import env_var +from google.cloud.aiplatform_v1.types import io +from google.cloud.aiplatform_v1.types import job_state +from google.cloud.aiplatform_v1.types import machine_resources +from google.protobuf import duration_pb2 as duration # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "CustomJob", + "CustomJobSpec", + "WorkerPoolSpec", + "ContainerSpec", + "PythonPackageSpec", + "Scheduling", + }, +) + + +class CustomJob(proto.Message): + r"""Represents a job that runs custom workloads such as a Docker + container or a Python package.
A CustomJob can have multiple + worker pools and each worker pool can have its own machine and + input spec. A CustomJob will be cleaned up once the job enters + terminal state (failed or succeeded). + + Attributes: + name (str): + Output only. Resource name of a CustomJob. + display_name (str): + Required. The display name of the CustomJob. + The name can be up to 128 characters long and + can consist of any UTF-8 characters. + job_spec (google.cloud.aiplatform_v1.types.CustomJobSpec): + Required. Job spec. + state (google.cloud.aiplatform_v1.types.JobState): + Output only. The detailed state of the job. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the CustomJob was + created. + start_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the CustomJob for the first time + entered the ``JOB_STATE_RUNNING`` state. + end_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the CustomJob entered any of the + following states: ``JOB_STATE_SUCCEEDED``, + ``JOB_STATE_FAILED``, ``JOB_STATE_CANCELLED``. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the CustomJob was most + recently updated. + error (google.rpc.status_pb2.Status): + Output only. Only populated when the job's state is + ``JOB_STATE_FAILED`` or ``JOB_STATE_CANCELLED``. + labels (Sequence[google.cloud.aiplatform_v1.types.CustomJob.LabelsEntry]): + The labels with user-defined metadata to + organize CustomJobs. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): + Customer-managed encryption key options for a + CustomJob. If this is set, then all resources + created by the CustomJob will be encrypted with + the provided encryption key. + """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + job_spec = proto.Field(proto.MESSAGE, number=4, message="CustomJobSpec",) + + state = proto.Field(proto.ENUM, number=5, enum=job_state.JobState,) + + create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + + start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + + end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + + error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) + + labels = proto.MapField(proto.STRING, proto.STRING, number=11) + + encryption_spec = proto.Field( + proto.MESSAGE, number=12, message=gca_encryption_spec.EncryptionSpec, + ) + + +class CustomJobSpec(proto.Message): + r"""Represents the spec of a CustomJob. + + Attributes: + worker_pool_specs (Sequence[google.cloud.aiplatform_v1.types.WorkerPoolSpec]): + Required. The spec of the worker pools + including machine type and Docker image. + scheduling (google.cloud.aiplatform_v1.types.Scheduling): + Scheduling options for a CustomJob. + service_account (str): + Specifies the service account for the workload + run-as account. Users submitting jobs must have + act-as permission on this run-as account. If + unspecified, the AI Platform Custom Code Service + Agent for the CustomJob's project is used.
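A sketch of how CustomJob and CustomJobSpec compose with the WorkerPoolSpec, ContainerSpec, and Scheduling messages defined below; every name and value here is a placeholder, and MachineSpec's accelerator fields live in machine_resources.py, outside this hunk:

from google.cloud import aiplatform_v1
from google.protobuf import duration_pb2

client = aiplatform_v1.JobServiceClient()

custom_job = aiplatform_v1.CustomJob(
    display_name="example-training-job",  # placeholder
    job_spec=aiplatform_v1.CustomJobSpec(
        worker_pool_specs=[
            aiplatform_v1.WorkerPoolSpec(
                machine_spec=aiplatform_v1.MachineSpec(
                    machine_type="n1-standard-8",
                    accelerator_type=aiplatform_v1.AcceleratorType.NVIDIA_TESLA_T4,
                    accelerator_count=1,
                ),
                replica_count=1,
                container_spec=aiplatform_v1.ContainerSpec(
                    image_uri="gcr.io/example-project/trainer:latest",  # placeholder
                    args=["--epochs", "10"],
                ),
            )
        ],
        scheduling=aiplatform_v1.Scheduling(
            timeout=duration_pb2.Duration(seconds=3600),  # 1 hour
            restart_job_on_worker_restart=False,
        ),
    ),
)

created = client.create_custom_job(
    parent="projects/example-project/locations/us-central1", custom_job=custom_job
)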
+ network (str): + The full name of the Compute Engine + `network `__ + to which the Job should be peered. For example, + ``projects/12345/global/networks/myVPC``. + `Format `__ + is of the form + ``projects/{project}/global/networks/{network}``. Where + {project} is a project number, as in ``12345``, and + {network} is a network name. + + Private services access must already be configured for the + network. If left unspecified, the job is not peered with any + network. + base_output_directory (google.cloud.aiplatform_v1.types.GcsDestination): + The Cloud Storage location to store the output of this + CustomJob or HyperparameterTuningJob. For + HyperparameterTuningJob, the baseOutputDirectory of each + child CustomJob backing a Trial is set to a subdirectory of + name ``id`` under its + parent HyperparameterTuningJob's baseOutputDirectory. + + The following AI Platform environment variables will be + passed to containers or python modules when this field is + set: + + For CustomJob: + + - AIP_MODEL_DIR = ``/model/`` + - AIP_CHECKPOINT_DIR = + ``/checkpoints/`` + - AIP_TENSORBOARD_LOG_DIR = + ``/logs/`` + + For CustomJob backing a Trial of HyperparameterTuningJob: + + - AIP_MODEL_DIR = + ``//model/`` + - AIP_CHECKPOINT_DIR = + ``//checkpoints/`` + - AIP_TENSORBOARD_LOG_DIR = + ``//logs/`` + """ + + worker_pool_specs = proto.RepeatedField( + proto.MESSAGE, number=1, message="WorkerPoolSpec", + ) + + scheduling = proto.Field(proto.MESSAGE, number=3, message="Scheduling",) + + service_account = proto.Field(proto.STRING, number=4) + + network = proto.Field(proto.STRING, number=5) + + base_output_directory = proto.Field( + proto.MESSAGE, number=6, message=io.GcsDestination, + ) + + +class WorkerPoolSpec(proto.Message): + r"""Represents the spec of a worker pool in a job. + + Attributes: + container_spec (google.cloud.aiplatform_v1.types.ContainerSpec): + The custom container task. + python_package_spec (google.cloud.aiplatform_v1.types.PythonPackageSpec): + The Python packaged task. + machine_spec (google.cloud.aiplatform_v1.types.MachineSpec): + Optional. Immutable. The specification of a + single machine. + replica_count (int): + Optional. The number of worker replicas to + use for this worker pool. + disk_spec (google.cloud.aiplatform_v1.types.DiskSpec): + Disk spec. + """ + + container_spec = proto.Field( + proto.MESSAGE, number=6, oneof="task", message="ContainerSpec", + ) + + python_package_spec = proto.Field( + proto.MESSAGE, number=7, oneof="task", message="PythonPackageSpec", + ) + + machine_spec = proto.Field( + proto.MESSAGE, number=1, message=machine_resources.MachineSpec, + ) + + replica_count = proto.Field(proto.INT64, number=2) + + disk_spec = proto.Field( + proto.MESSAGE, number=5, message=machine_resources.DiskSpec, + ) + + +class ContainerSpec(proto.Message): + r"""The spec of a Container. + + Attributes: + image_uri (str): + Required. The URI of a container image in the + Container Registry that is to be run on each + worker replica. + command (Sequence[str]): + The command to be invoked when the container + is started. It overrides the entrypoint + instruction in Dockerfile when provided. + args (Sequence[str]): + The arguments to be passed when starting the + container. + env (Sequence[google.cloud.aiplatform_v1.types.EnvVar]): + Environment variables to be passed to the + container. 
+ """ + + image_uri = proto.Field(proto.STRING, number=1) + + command = proto.RepeatedField(proto.STRING, number=2) + + args = proto.RepeatedField(proto.STRING, number=3) + + env = proto.RepeatedField(proto.MESSAGE, number=4, message=env_var.EnvVar,) + + +class PythonPackageSpec(proto.Message): + r"""The spec of a Python packaged code. + + Attributes: + executor_image_uri (str): + Required. The URI of a container image in the + Container Registry that will run the provided + python package. AI Platform provides wide range + of executor images with pre-installed packages + to meet users' various use cases. Only one of + the provided images can be set here. + package_uris (Sequence[str]): + Required. The Google Cloud Storage location + of the Python package files which are the + training program and its dependent packages. The + maximum number of package URIs is 100. + python_module (str): + Required. The Python module name to run after + installing the packages. + args (Sequence[str]): + Command line arguments to be passed to the + Python task. + env (Sequence[google.cloud.aiplatform_v1.types.EnvVar]): + Environment variables to be passed to the + python module. + """ + + executor_image_uri = proto.Field(proto.STRING, number=1) + + package_uris = proto.RepeatedField(proto.STRING, number=2) + + python_module = proto.Field(proto.STRING, number=3) + + args = proto.RepeatedField(proto.STRING, number=4) + + env = proto.RepeatedField(proto.MESSAGE, number=5, message=env_var.EnvVar,) + + +class Scheduling(proto.Message): + r"""All parameters related to queuing and scheduling of custom + jobs. + + Attributes: + timeout (google.protobuf.duration_pb2.Duration): + The maximum job running time. The default is + 7 days. + restart_job_on_worker_restart (bool): + Restarts the entire CustomJob if a worker + gets restarted. This feature can be used by + distributed training jobs that are not resilient + to workers leaving and joining a job. + """ + + timeout = proto.Field(proto.MESSAGE, number=1, message=duration.Duration,) + + restart_job_on_worker_restart = proto.Field(proto.BOOL, number=3) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/data_item.py b/google/cloud/aiplatform_v1/types/data_item.py new file mode 100644 index 0000000000..20ff14a0d8 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/data_item.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"DataItem",}, +) + + +class DataItem(proto.Message): + r"""A piece of data in a Dataset. Could be an image, a video, a + document or plain text. + + Attributes: + name (str): + Output only. The resource name of the + DataItem. 
+ create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this DataItem was + created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this DataItem was + last updated. + labels (Sequence[google.cloud.aiplatform_v1.types.DataItem.LabelsEntry]): + Optional. The labels with user-defined + metadata to organize your DataItems. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. No more than 64 user labels can be + associated with one DataItem (System labels are + excluded). + + See https://goo.gl/xmQnxf for more information + and examples of labels. System reserved label + keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. + payload (google.protobuf.struct_pb2.Value): + Required. The data that the DataItem represents (for + example, an image or a text snippet). The schema of the + payload is stored in the parent Dataset's [metadata + schema's][google.cloud.aiplatform.v1.Dataset.metadata_schema_uri] + dataItemSchemaUri field. + etag (str): + Optional. Used to perform consistent read-modify-write updates. If not set, a blind + "overwrite" update happens. + """ + + name = proto.Field(proto.STRING, number=1) + + create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + + labels = proto.MapField(proto.STRING, proto.STRING, number=3) + + payload = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + + etag = proto.Field(proto.STRING, number=7) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/data_labeling_job.py b/google/cloud/aiplatform_v1/types/data_labeling_job.py new file mode 100644 index 0000000000..e1058737bf --- /dev/null +++ b/google/cloud/aiplatform_v1/types/data_labeling_job.py @@ -0,0 +1,275 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform_v1.types import job_state +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore +from google.type import money_pb2 as money # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "DataLabelingJob", + "ActiveLearningConfig", + "SampleConfig", + "TrainingConfig", + }, +) + + +class DataLabelingJob(proto.Message): + r"""DataLabelingJob is used to trigger a human labeling job on + unlabeled data from the following Dataset: + + Attributes: + name (str): + Output only. Resource name of the + DataLabelingJob. + display_name (str): + Required.
The user-defined name of the + DataLabelingJob. The name can be up to 128 + characters long and can consist of any UTF-8 + characters. + datasets (Sequence[str]): + Required. Dataset resource names. Right now we only support + labeling from a single Dataset. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + annotation_labels (Sequence[google.cloud.aiplatform_v1.types.DataLabelingJob.AnnotationLabelsEntry]): + Labels to assign to annotations generated by + this DataLabelingJob. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. See https://goo.gl/xmQnxf for more + information and examples of labels. System + reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. + labeler_count (int): + Required. Number of labelers to work on each + DataItem. + instruction_uri (str): + Required. The Google Cloud Storage location + of the instruction pdf. This pdf is shared with + labelers, and provides detailed description on + how to label DataItems in Datasets. + inputs_schema_uri (str): + Required. Points to a YAML file stored on + Google Cloud Storage describing the config for a + specific type of DataLabelingJob. The schema + files that can be used here are found in the + https://storage.googleapis.com/google-cloud-aiplatform bucket in the + /schema/datalabelingjob/inputs/ folder. + inputs (google.protobuf.struct_pb2.Value): + Required. Input config parameters for the + DataLabelingJob. + state (google.cloud.aiplatform_v1.types.JobState): + Output only. The detailed state of the job. + labeling_progress (int): + Output only. Current labeling job progress percentage scaled + in interval [0, 100], indicating the percentage of DataItems + that have been finished. + current_spend (google.type.money_pb2.Money): + Output only. Estimated cost (in US dollars) + that the DataLabelingJob has incurred to date. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + DataLabelingJob was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + DataLabelingJob was updated most recently. + error (google.rpc.status_pb2.Status): + Output only. DataLabelingJob errors. It is only populated + when the job's state is ``JOB_STATE_FAILED`` or + ``JOB_STATE_CANCELLED``. + labels (Sequence[google.cloud.aiplatform_v1.types.DataLabelingJob.LabelsEntry]): + The labels with user-defined metadata to organize your + DataLabelingJobs. + + Label keys and values can be no longer than 64 characters + (Unicode codepoints), can only contain lowercase letters, + numeric characters, underscores and dashes. International + characters are allowed. + + See https://goo.gl/xmQnxf for more information and examples + of labels. System reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. Following + system labels exist for each DataLabelingJob: + + - "aiplatform.googleapis.com/schema": output only, its + value is the + ``inputs_schema``'s + title. + specialist_pools (Sequence[str]): + The SpecialistPools' resource names + associated with this job. + encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): + Customer-managed encryption key spec for a + DataLabelingJob. If set, this DataLabelingJob + will be secured by this key.
+ Note: Annotations created in the DataLabelingJob + are associated with the EncryptionSpec of the + Dataset they are exported to. + active_learning_config (google.cloud.aiplatform_v1.types.ActiveLearningConfig): + Parameters that configure the active learning + pipeline. Active learning will label the data + incrementally via several iterations. For every + iteration, it will select a batch of data based + on the sampling strategy. + """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + datasets = proto.RepeatedField(proto.STRING, number=3) + + annotation_labels = proto.MapField(proto.STRING, proto.STRING, number=12) + + labeler_count = proto.Field(proto.INT32, number=4) + + instruction_uri = proto.Field(proto.STRING, number=5) + + inputs_schema_uri = proto.Field(proto.STRING, number=6) + + inputs = proto.Field(proto.MESSAGE, number=7, message=struct.Value,) + + state = proto.Field(proto.ENUM, number=8, enum=job_state.JobState,) + + labeling_progress = proto.Field(proto.INT32, number=13) + + current_spend = proto.Field(proto.MESSAGE, number=14, message=money.Money,) + + create_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=10, message=timestamp.Timestamp,) + + error = proto.Field(proto.MESSAGE, number=22, message=status.Status,) + + labels = proto.MapField(proto.STRING, proto.STRING, number=11) + + specialist_pools = proto.RepeatedField(proto.STRING, number=16) + + encryption_spec = proto.Field( + proto.MESSAGE, number=20, message=gca_encryption_spec.EncryptionSpec, + ) + + active_learning_config = proto.Field( + proto.MESSAGE, number=21, message="ActiveLearningConfig", + ) + + +class ActiveLearningConfig(proto.Message): + r"""Parameters that configure the active learning pipeline. + Active learning will label the data incrementally by several + iterations. For every iteration, it will select a batch of data + based on the sampling strategy. + + Attributes: + max_data_item_count (int): + Max number of human labeled DataItems. + max_data_item_percentage (int): + Max percent of total DataItems for human + labeling. + sample_config (google.cloud.aiplatform_v1.types.SampleConfig): + Active learning data sampling config. For + every active learning labeling iteration, it + will select a batch of data based on the + sampling strategy. + training_config (google.cloud.aiplatform_v1.types.TrainingConfig): + CMLE training config. For every active + learning labeling iteration, system will train a + machine learning model on CMLE. The trained + model will be used by data sampling algorithm to + select DataItems. + """ + + max_data_item_count = proto.Field( + proto.INT64, number=1, oneof="human_labeling_budget" + ) + + max_data_item_percentage = proto.Field( + proto.INT32, number=2, oneof="human_labeling_budget" + ) + + sample_config = proto.Field(proto.MESSAGE, number=3, message="SampleConfig",) + + training_config = proto.Field(proto.MESSAGE, number=4, message="TrainingConfig",) + + +class SampleConfig(proto.Message): + r"""Active learning data sampling config. For every active + learning labeling iteration, it will select a batch of data + based on the sampling strategy. + + Attributes: + initial_batch_sample_percentage (int): + The percentage of data needed to be labeled + in the first batch. + following_batch_sample_percentage (int): + The percentage of data needed to be labeled + in each following batch (except the first + batch). 
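Read together with sample_strategy, which is described just below, the active learning messages compose like this; a hedged sketch with illustrative numbers only:

from google.cloud import aiplatform_v1

active_learning_config = aiplatform_v1.ActiveLearningConfig(
    max_data_item_count=10000,  # one of the human_labeling_budget choices
    sample_config=aiplatform_v1.SampleConfig(
        initial_batch_sample_percentage=10,
        following_batch_sample_percentage=5,
        sample_strategy=aiplatform_v1.SampleConfig.SampleStrategy.UNCERTAINTY,
    ),
    training_config=aiplatform_v1.TrainingConfig(
        timeout_training_milli_hours=2000,  # 2 hours, in the milli-hour unit below
    ),
)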
+ sample_strategy (google.cloud.aiplatform_v1.types.SampleConfig.SampleStrategy): + Field to choose sampling strategy. Sampling + strategy will decide which data should be + selected for human labeling in every batch. + """ + + class SampleStrategy(proto.Enum): + r"""Sample strategy decides which subset of DataItems should be + selected for human labeling in every batch. + """ + SAMPLE_STRATEGY_UNSPECIFIED = 0 + UNCERTAINTY = 1 + + initial_batch_sample_percentage = proto.Field( + proto.INT32, number=1, oneof="initial_batch_sample_size" + ) + + following_batch_sample_percentage = proto.Field( + proto.INT32, number=3, oneof="following_batch_sample_size" + ) + + sample_strategy = proto.Field(proto.ENUM, number=5, enum=SampleStrategy,) + + +class TrainingConfig(proto.Message): + r"""CMLE training config. For every active learning labeling + iteration, the system will train a machine learning model on CMLE. + The trained model will be used by the data sampling algorithm to + select DataItems. + + Attributes: + timeout_training_milli_hours (int): + The timeout for the CMLE training job, + expressed in milli-hours: a value of 1,000 in + this field means 1 hour. + """ + + timeout_training_milli_hours = proto.Field(proto.INT64, number=1) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/dataset.py b/google/cloud/aiplatform_v1/types/dataset.py new file mode 100644 index 0000000000..2f75dce0d5 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/dataset.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform_v1.types import io +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={"Dataset", "ImportDataConfig", "ExportDataConfig",}, +) + + +class Dataset(proto.Message): + r"""A collection of DataItems and Annotations on them. + + Attributes: + name (str): + Output only. The resource name of the + Dataset. + display_name (str): + Required. The user-defined name of the + Dataset. The name can be up to 128 characters + long and can consist of any UTF-8 characters. + metadata_schema_uri (str): + Required. Points to a YAML file stored on + Google Cloud Storage describing additional + information about the Dataset. The schema is + defined as an OpenAPI 3.0.2 Schema Object. The + schema files that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/metadata/. + metadata (google.protobuf.struct_pb2.Value): + Required. Additional information about the + Dataset. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Dataset was + created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only.
Timestamp when this Dataset was + last updated. + etag (str): + Used to perform consistent read-modify-write + updates. If not set, a blind "overwrite" update + happens. + labels (Sequence[google.cloud.aiplatform_v1.types.Dataset.LabelsEntry]): + The labels with user-defined metadata to organize your + Datasets. + + Label keys and values can be no longer than 64 characters + (Unicode codepoints), can only contain lowercase letters, + numeric characters, underscores and dashes. International + characters are allowed. No more than 64 user labels can be + associated with one Dataset (System labels are excluded). + + See https://goo.gl/xmQnxf for more information and examples + of labels. System reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. Following + system labels exist for each Dataset: + + - "aiplatform.googleapis.com/dataset_metadata_schema": + output only, its value is the + [metadata_schema's][google.cloud.aiplatform.v1.Dataset.metadata_schema_uri] + title. + encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): + Customer-managed encryption key spec for a + Dataset. If set, this Dataset and all + sub-resources of this Dataset will be secured by + this key. + """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + metadata_schema_uri = proto.Field(proto.STRING, number=3) + + metadata = proto.Field(proto.MESSAGE, number=8, message=struct.Value,) + + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + + etag = proto.Field(proto.STRING, number=6) + + labels = proto.MapField(proto.STRING, proto.STRING, number=7) + + encryption_spec = proto.Field( + proto.MESSAGE, number=11, message=gca_encryption_spec.EncryptionSpec, + ) + + +class ImportDataConfig(proto.Message): + r"""Describes the location from where we import data into a + Dataset, together with the labels that will be applied to the + DataItems and the Annotations. + + Attributes: + gcs_source (google.cloud.aiplatform_v1.types.GcsSource): + The Google Cloud Storage location for the + input content. + data_item_labels (Sequence[google.cloud.aiplatform_v1.types.ImportDataConfig.DataItemLabelsEntry]): + Labels that will be applied to newly imported DataItems. If + an identical DataItem as one being imported already exists + in the Dataset, then these labels will be appended to those + of the already existing one, and if a label with an identical + key was imported before, the old label value will be + overwritten. If two DataItems are identical in the same + import data operation, the labels will be combined and if + key collision happens in this case, one of the values will + be picked randomly. Two DataItems are considered identical + if their content bytes are identical (e.g. image bytes or + pdf bytes). These labels will be overridden by Annotation + labels specified inside the index file referenced by + ``import_schema_uri``, + e.g. jsonl file. + import_schema_uri (str): + Required. Points to a YAML file stored on Google Cloud + Storage describing the import format. Validation will be + done against the schema. The schema is defined as an + `OpenAPI 3.0.2 Schema + Object `__.
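A hedged sketch of driving an import with this config through the generated DatasetServiceClient (bucket, dataset name, and schema URI are placeholders; import_data itself is defined in dataset_service.py below):

from google.cloud import aiplatform_v1

client = aiplatform_v1.DatasetServiceClient()

config = aiplatform_v1.ImportDataConfig(
    gcs_source=aiplatform_v1.GcsSource(uris=["gs://example-bucket/import.jsonl"]),
    import_schema_uri=(
        "gs://google-cloud-aiplatform/schema/dataset/ioformat/"
        "image_classification_single_label_io_format_1.0.0.yaml"  # placeholder schema
    ),
)

operation = client.import_data(
    name="projects/123/locations/us-central1/datasets/456",  # placeholder
    import_configs=[config],
)
operation.result()  # block until the long-running operation finishes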
+ """ + + gcs_source = proto.Field( + proto.MESSAGE, number=1, oneof="source", message=io.GcsSource, + ) + + data_item_labels = proto.MapField(proto.STRING, proto.STRING, number=2) + + import_schema_uri = proto.Field(proto.STRING, number=4) + + +class ExportDataConfig(proto.Message): + r"""Describes what part of the Dataset is to be exported, the + destination of the export and how to export. + + Attributes: + gcs_destination (google.cloud.aiplatform_v1.types.GcsDestination): + The Google Cloud Storage location where the output is to be + written to. In the given directory a new directory will be + created with name: + ``export-data--`` + where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 + format. All export output will be written into that + directory. Inside that directory, annotations with the same + schema will be grouped into sub directories which are named + with the corresponding annotations' schema title. Inside + these sub directories, a schema.yaml will be created to + describe the output format. + annotations_filter (str): + A filter on Annotations of the Dataset. Only Annotations on + to-be-exported DataItems(specified by [data_items_filter][]) + that match this filter will be exported. The filter syntax + is the same as in + ``ListAnnotations``. + """ + + gcs_destination = proto.Field( + proto.MESSAGE, number=1, oneof="destination", message=io.GcsDestination, + ) + + annotations_filter = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/dataset_service.py b/google/cloud/aiplatform_v1/types/dataset_service.py new file mode 100644 index 0000000000..ccc8cce600 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/dataset_service.py @@ -0,0 +1,446 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import annotation +from google.cloud.aiplatform_v1.types import data_item +from google.cloud.aiplatform_v1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1.types import operation +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "CreateDatasetRequest", + "CreateDatasetOperationMetadata", + "GetDatasetRequest", + "UpdateDatasetRequest", + "ListDatasetsRequest", + "ListDatasetsResponse", + "DeleteDatasetRequest", + "ImportDataRequest", + "ImportDataResponse", + "ImportDataOperationMetadata", + "ExportDataRequest", + "ExportDataResponse", + "ExportDataOperationMetadata", + "ListDataItemsRequest", + "ListDataItemsResponse", + "GetAnnotationSpecRequest", + "ListAnnotationsRequest", + "ListAnnotationsResponse", + }, +) + + +class CreateDatasetRequest(proto.Message): + r"""Request message for + ``DatasetService.CreateDataset``. + + Attributes: + parent (str): + Required. The resource name of the Location to create the + Dataset in. 
+            ``projects/{project}/locations/{location}``
+        dataset (google.cloud.aiplatform_v1.types.Dataset):
+            Required. The Dataset to create.
+    """
+
+    parent = proto.Field(proto.STRING, number=1)
+
+    dataset = proto.Field(proto.MESSAGE, number=2, message=gca_dataset.Dataset,)
+
+
+class CreateDatasetOperationMetadata(proto.Message):
+    r"""Runtime operation information for
+    ``DatasetService.CreateDataset``.
+
+    Attributes:
+        generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata):
+            The operation generic information.
+    """
+
+    generic_metadata = proto.Field(
+        proto.MESSAGE, number=1, message=operation.GenericOperationMetadata,
+    )
+
+
+class GetDatasetRequest(proto.Message):
+    r"""Request message for
+    ``DatasetService.GetDataset``.
+
+    Attributes:
+        name (str):
+            Required. The name of the Dataset resource.
+        read_mask (google.protobuf.field_mask_pb2.FieldMask):
+            Mask specifying which fields to read.
+    """
+
+    name = proto.Field(proto.STRING, number=1)
+
+    read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,)
+
+
+class UpdateDatasetRequest(proto.Message):
+    r"""Request message for
+    ``DatasetService.UpdateDataset``.
+
+    Attributes:
+        dataset (google.cloud.aiplatform_v1.types.Dataset):
+            Required. The Dataset which replaces the
+            resource on the server.
+        update_mask (google.protobuf.field_mask_pb2.FieldMask):
+            Required. The update mask applies to the resource. For the
+            ``FieldMask`` definition, see
+            `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__.
+            Updatable fields:
+
+            -  ``display_name``
+            -  ``description``
+            -  ``labels``
+    """
+
+    dataset = proto.Field(proto.MESSAGE, number=1, message=gca_dataset.Dataset,)
+
+    update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,)
+
+
+class ListDatasetsRequest(proto.Message):
+    r"""Request message for
+    ``DatasetService.ListDatasets``.
+
+    Attributes:
+        parent (str):
+            Required. The name of the Dataset's parent resource. Format:
+            ``projects/{project}/locations/{location}``
+        filter (str):
+            An expression for filtering the results of the request. For
+            field names both snake_case and camelCase are supported.
+
+            -  ``display_name``: supports = and !=
+            -  ``metadata_schema_uri``: supports = and !=
+            -  ``labels`` supports general map functions, that is:
+
+               -  ``labels.key=value`` - key:value equality
+               -  ``labels.key:*`` or ``labels:key`` - key existence
+               -  A key including a space must be quoted:
+                  ``labels."a key"``.
+
+            Some examples:
+
+            -  ``displayName="myDisplayName"``
+            -  ``labels.myKey="myValue"``
+        page_size (int):
+            The standard list page size.
+        page_token (str):
+            The standard list page token.
+        read_mask (google.protobuf.field_mask_pb2.FieldMask):
+            Mask specifying which fields to read.
+        order_by (str):
+            A comma-separated list of fields to order by, sorted in
+            ascending order. Use "desc" after a field name for
+            descending. Supported fields:
+
+            -  ``display_name``
+            -  ``create_time``
+            -  ``update_time``
+    """
+
+    parent = proto.Field(proto.STRING, number=1)
+
+    filter = proto.Field(proto.STRING, number=2)
+
+    page_size = proto.Field(proto.INT32, number=3)
+
+    page_token = proto.Field(proto.STRING, number=4)
+
+    read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,)
+
+    order_by = proto.Field(proto.STRING, number=6)
+
+
+class ListDatasetsResponse(proto.Message):
+    r"""Response message for
+    ``DatasetService.ListDatasets``.
+ + Attributes: + datasets (Sequence[google.cloud.aiplatform_v1.types.Dataset]): + A list of Datasets that matches the specified + filter in the request. + next_page_token (str): + The standard List next-page token. + """ + + @property + def raw_page(self): + return self + + datasets = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_dataset.Dataset, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class DeleteDatasetRequest(proto.Message): + r"""Request message for + ``DatasetService.DeleteDataset``. + + Attributes: + name (str): + Required. The resource name of the Dataset to delete. + Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ImportDataRequest(proto.Message): + r"""Request message for + ``DatasetService.ImportData``. + + Attributes: + name (str): + Required. The name of the Dataset resource. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + import_configs (Sequence[google.cloud.aiplatform_v1.types.ImportDataConfig]): + Required. The desired input locations. The + contents of all input locations will be imported + in one batch. + """ + + name = proto.Field(proto.STRING, number=1) + + import_configs = proto.RepeatedField( + proto.MESSAGE, number=2, message=gca_dataset.ImportDataConfig, + ) + + +class ImportDataResponse(proto.Message): + r"""Response message for + ``DatasetService.ImportData``. + """ + + +class ImportDataOperationMetadata(proto.Message): + r"""Runtime operation information for + ``DatasetService.ImportData``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata): + The common part of the operation metadata. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + +class ExportDataRequest(proto.Message): + r"""Request message for + ``DatasetService.ExportData``. + + Attributes: + name (str): + Required. The name of the Dataset resource. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + export_config (google.cloud.aiplatform_v1.types.ExportDataConfig): + Required. The desired output location. + """ + + name = proto.Field(proto.STRING, number=1) + + export_config = proto.Field( + proto.MESSAGE, number=2, message=gca_dataset.ExportDataConfig, + ) + + +class ExportDataResponse(proto.Message): + r"""Response message for + ``DatasetService.ExportData``. + + Attributes: + exported_files (Sequence[str]): + All of the files that are exported in this + export operation. + """ + + exported_files = proto.RepeatedField(proto.STRING, number=1) + + +class ExportDataOperationMetadata(proto.Message): + r"""Runtime operation information for + ``DatasetService.ExportData``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata): + The common part of the operation metadata. + gcs_output_directory (str): + A Google Cloud Storage directory which path + ends with '/'. The exported data is stored in + the directory. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + gcs_output_directory = proto.Field(proto.STRING, number=2) + + +class ListDataItemsRequest(proto.Message): + r"""Request message for + ``DatasetService.ListDataItems``. + + Attributes: + parent (str): + Required. The resource name of the Dataset to list DataItems + from. 
Format: + ``projects/{project}/locations/{location}/datasets/{dataset}`` + filter (str): + The standard list filter. + page_size (int): + The standard list page size. + page_token (str): + The standard list page token. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + order_by (str): + A comma-separated list of fields to order by, + sorted in ascending order. Use "desc" after a + field name for descending. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + order_by = proto.Field(proto.STRING, number=6) + + +class ListDataItemsResponse(proto.Message): + r"""Response message for + ``DatasetService.ListDataItems``. + + Attributes: + data_items (Sequence[google.cloud.aiplatform_v1.types.DataItem]): + A list of DataItems that matches the + specified filter in the request. + next_page_token (str): + The standard List next-page token. + """ + + @property + def raw_page(self): + return self + + data_items = proto.RepeatedField( + proto.MESSAGE, number=1, message=data_item.DataItem, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class GetAnnotationSpecRequest(proto.Message): + r"""Request message for + ``DatasetService.GetAnnotationSpec``. + + Attributes: + name (str): + Required. The name of the AnnotationSpec resource. Format: + + ``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}`` + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + """ + + name = proto.Field(proto.STRING, number=1) + + read_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + + +class ListAnnotationsRequest(proto.Message): + r"""Request message for + ``DatasetService.ListAnnotations``. + + Attributes: + parent (str): + Required. The resource name of the DataItem to list + Annotations from. Format: + + ``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}`` + filter (str): + The standard list filter. + page_size (int): + The standard list page size. + page_token (str): + The standard list page token. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + order_by (str): + A comma-separated list of fields to order by, + sorted in ascending order. Use "desc" after a + field name for descending. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + order_by = proto.Field(proto.STRING, number=6) + + +class ListAnnotationsResponse(proto.Message): + r"""Response message for + ``DatasetService.ListAnnotations``. + + Attributes: + annotations (Sequence[google.cloud.aiplatform_v1.types.Annotation]): + A list of Annotations that matches the + specified filter in the request. + next_page_token (str): + The standard List next-page token. 
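+
+    Example (an illustrative sketch of the request that produces this
+    response; the project, location, and resource IDs below are
+    hypothetical placeholders)::
+
+        from google.cloud.aiplatform_v1.types import dataset_service
+
+        request = dataset_service.ListAnnotationsRequest(
+            # DataItem whose Annotations are listed.
+            parent=(
+                "projects/my-project/locations/us-central1"
+                "/datasets/123/dataItems/456"
+            ),
+            page_size=50,
+        )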
+ """ + + @property + def raw_page(self): + return self + + annotations = proto.RepeatedField( + proto.MESSAGE, number=1, message=annotation.Annotation, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/deployed_model_ref.py b/google/cloud/aiplatform_v1/types/deployed_model_ref.py new file mode 100644 index 0000000000..2d53610ed5 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/deployed_model_ref.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"DeployedModelRef",}, +) + + +class DeployedModelRef(proto.Message): + r"""Points to a DeployedModel. + + Attributes: + endpoint (str): + Immutable. A resource name of an Endpoint. + deployed_model_id (str): + Immutable. An ID of a DeployedModel in the + above Endpoint. + """ + + endpoint = proto.Field(proto.STRING, number=1) + + deployed_model_id = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/encryption_spec.py b/google/cloud/aiplatform_v1/types/encryption_spec.py new file mode 100644 index 0000000000..ae908d4b72 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/encryption_spec.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"EncryptionSpec",}, +) + + +class EncryptionSpec(proto.Message): + r"""Represents a customer-managed encryption key spec that can be + applied to a top-level resource. + + Attributes: + kms_key_name (str): + Required. The Cloud KMS resource identifier of the customer + managed encryption key used to protect a resource. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. 
+ """ + + kms_key_name = proto.Field(proto.STRING, number=1) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/endpoint.py b/google/cloud/aiplatform_v1/types/endpoint.py new file mode 100644 index 0000000000..5cbe3c1b1d --- /dev/null +++ b/google/cloud/aiplatform_v1/types/endpoint.py @@ -0,0 +1,200 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform_v1.types import machine_resources +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"Endpoint", "DeployedModel",}, +) + + +class Endpoint(proto.Message): + r"""Models are deployed into it, and afterwards Endpoint is + called to obtain predictions and explanations. + + Attributes: + name (str): + Output only. The resource name of the + Endpoint. + display_name (str): + Required. The display name of the Endpoint. + The name can be up to 128 characters long and + can be consist of any UTF-8 characters. + description (str): + The description of the Endpoint. + deployed_models (Sequence[google.cloud.aiplatform_v1.types.DeployedModel]): + Output only. The models deployed in this Endpoint. To add or + remove DeployedModels use + ``EndpointService.DeployModel`` + and + ``EndpointService.UndeployModel`` + respectively. + traffic_split (Sequence[google.cloud.aiplatform_v1.types.Endpoint.TrafficSplitEntry]): + A map from a DeployedModel's ID to the + percentage of this Endpoint's traffic that + should be forwarded to that DeployedModel. + If a DeployedModel's ID is not listed in this + map, then it receives no traffic. + + The traffic percentage values must add up to + 100, or map must be empty if the Endpoint is to + not accept any traffic at a moment. + etag (str): + Used to perform consistent read-modify-write + updates. If not set, a blind "overwrite" update + happens. + labels (Sequence[google.cloud.aiplatform_v1.types.Endpoint.LabelsEntry]): + The labels with user-defined metadata to + organize your Endpoints. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Endpoint was + created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Endpoint was + last updated. + encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): + Customer-managed encryption key spec for an + Endpoint. If set, this Endpoint and all sub- + resources of this Endpoint will be secured by + this key. 
+ """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + description = proto.Field(proto.STRING, number=3) + + deployed_models = proto.RepeatedField( + proto.MESSAGE, number=4, message="DeployedModel", + ) + + traffic_split = proto.MapField(proto.STRING, proto.INT32, number=5) + + etag = proto.Field(proto.STRING, number=6) + + labels = proto.MapField(proto.STRING, proto.STRING, number=7) + + create_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + + encryption_spec = proto.Field( + proto.MESSAGE, number=10, message=gca_encryption_spec.EncryptionSpec, + ) + + +class DeployedModel(proto.Message): + r"""A deployment of a Model. Endpoints contain one or more + DeployedModels. + + Attributes: + dedicated_resources (google.cloud.aiplatform_v1.types.DedicatedResources): + A description of resources that are dedicated + to the DeployedModel, and that need a higher + degree of manual configuration. + automatic_resources (google.cloud.aiplatform_v1.types.AutomaticResources): + A description of resources that to large + degree are decided by AI Platform, and require + only a modest additional configuration. + id (str): + Output only. The ID of the DeployedModel. + model (str): + Required. The name of the Model that this is + the deployment of. Note that the Model may be in + a different location than the DeployedModel's + Endpoint. + display_name (str): + The display name of the DeployedModel. If not provided upon + creation, the Model's display_name is used. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when the DeployedModel + was created. + service_account (str): + The service account that the DeployedModel's container runs + as. Specify the email address of the service account. If + this service account is not specified, the container runs as + a service account that doesn't have access to the resource + project. + + Users deploying the Model must have the + ``iam.serviceAccounts.actAs`` permission on this service + account. + disable_container_logging (bool): + For custom-trained Models and AutoML Tabular Models, the + container of the DeployedModel instances will send + ``stderr`` and ``stdout`` streams to Stackdriver Logging by + default. Please note that the logs incur cost, which are + subject to `Cloud Logging + pricing `__. + + User can disable container logging by setting this flag to + true. + enable_access_logging (bool): + These logs are like standard server access + logs, containing information like timestamp and + latency for each prediction request. + Note that Stackdriver logs may incur a cost, + especially if your project receives prediction + requests at a high queries per second rate + (QPS). Estimate your costs before enabling this + option. 
+ """ + + dedicated_resources = proto.Field( + proto.MESSAGE, + number=7, + oneof="prediction_resources", + message=machine_resources.DedicatedResources, + ) + + automatic_resources = proto.Field( + proto.MESSAGE, + number=8, + oneof="prediction_resources", + message=machine_resources.AutomaticResources, + ) + + id = proto.Field(proto.STRING, number=1) + + model = proto.Field(proto.STRING, number=2) + + display_name = proto.Field(proto.STRING, number=3) + + create_time = proto.Field(proto.MESSAGE, number=6, message=timestamp.Timestamp,) + + service_account = proto.Field(proto.STRING, number=11) + + disable_container_logging = proto.Field(proto.BOOL, number=15) + + enable_access_logging = proto.Field(proto.BOOL, number=13) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/endpoint_service.py b/google/cloud/aiplatform_v1/types/endpoint_service.py new file mode 100644 index 0000000000..24e00bd486 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/endpoint_service.py @@ -0,0 +1,337 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint +from google.cloud.aiplatform_v1.types import operation +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "CreateEndpointRequest", + "CreateEndpointOperationMetadata", + "GetEndpointRequest", + "ListEndpointsRequest", + "ListEndpointsResponse", + "UpdateEndpointRequest", + "DeleteEndpointRequest", + "DeployModelRequest", + "DeployModelResponse", + "DeployModelOperationMetadata", + "UndeployModelRequest", + "UndeployModelResponse", + "UndeployModelOperationMetadata", + }, +) + + +class CreateEndpointRequest(proto.Message): + r"""Request message for + ``EndpointService.CreateEndpoint``. + + Attributes: + parent (str): + Required. The resource name of the Location to create the + Endpoint in. Format: + ``projects/{project}/locations/{location}`` + endpoint (google.cloud.aiplatform_v1.types.Endpoint): + Required. The Endpoint to create. + """ + + parent = proto.Field(proto.STRING, number=1) + + endpoint = proto.Field(proto.MESSAGE, number=2, message=gca_endpoint.Endpoint,) + + +class CreateEndpointOperationMetadata(proto.Message): + r"""Runtime operation information for + ``EndpointService.CreateEndpoint``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata): + The operation generic information. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + +class GetEndpointRequest(proto.Message): + r"""Request message for + ``EndpointService.GetEndpoint`` + + Attributes: + name (str): + Required. The name of the Endpoint resource. 
Format:
+            ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+    """
+
+    name = proto.Field(proto.STRING, number=1)
+
+
+class ListEndpointsRequest(proto.Message):
+    r"""Request message for
+    ``EndpointService.ListEndpoints``.
+
+    Attributes:
+        parent (str):
+            Required. The resource name of the Location from which to
+            list the Endpoints. Format:
+            ``projects/{project}/locations/{location}``
+        filter (str):
+            Optional. An expression for filtering the results of the
+            request. For field names both snake_case and camelCase are
+            supported.
+
+            -  ``endpoint`` supports = and !=. ``endpoint`` represents
+               the Endpoint ID, i.e. the last segment of the Endpoint's
+               [resource
+               name][google.cloud.aiplatform.v1.Endpoint.name].
+            -  ``display_name`` supports = and !=
+            -  ``labels`` supports general map functions, that is:
+
+               -  ``labels.key=value`` - key:value equality
+               -  ``labels.key:*`` or ``labels:key`` - key existence
+               -  A key including a space must be quoted:
+                  ``labels."a key"``.
+
+            Some examples:
+
+            -  ``endpoint=1``
+            -  ``displayName="myDisplayName"``
+            -  ``labels.myKey="myValue"``
+        page_size (int):
+            Optional. The standard list page size.
+        page_token (str):
+            Optional. The standard list page token. Typically obtained
+            via
+            ``ListEndpointsResponse.next_page_token``
+            of the previous
+            ``EndpointService.ListEndpoints``
+            call.
+        read_mask (google.protobuf.field_mask_pb2.FieldMask):
+            Optional. Mask specifying which fields to
+            read.
+        order_by (str):
+            A comma-separated list of fields to order by, sorted in
+            ascending order. Use "desc" after a field name for
+            descending. Supported fields:
+
+            -  ``display_name``
+            -  ``create_time``
+            -  ``update_time``
+
+            Example: ``display_name, create_time desc``.
+    """
+
+    parent = proto.Field(proto.STRING, number=1)
+
+    filter = proto.Field(proto.STRING, number=2)
+
+    page_size = proto.Field(proto.INT32, number=3)
+
+    page_token = proto.Field(proto.STRING, number=4)
+
+    read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,)
+
+    order_by = proto.Field(proto.STRING, number=6)
+
+
+class ListEndpointsResponse(proto.Message):
+    r"""Response message for
+    ``EndpointService.ListEndpoints``.
+
+    Attributes:
+        endpoints (Sequence[google.cloud.aiplatform_v1.types.Endpoint]):
+            List of Endpoints in the requested page.
+        next_page_token (str):
+            A token to retrieve the next page of results. Pass to
+            ``ListEndpointsRequest.page_token``
+            to obtain that page.
+    """
+
+    @property
+    def raw_page(self):
+        return self
+
+    endpoints = proto.RepeatedField(
+        proto.MESSAGE, number=1, message=gca_endpoint.Endpoint,
+    )
+
+    next_page_token = proto.Field(proto.STRING, number=2)
+
+
+class UpdateEndpointRequest(proto.Message):
+    r"""Request message for
+    ``EndpointService.UpdateEndpoint``.
+
+    Attributes:
+        endpoint (google.cloud.aiplatform_v1.types.Endpoint):
+            Required. The Endpoint which replaces the
+            resource on the server.
+        update_mask (google.protobuf.field_mask_pb2.FieldMask):
+            Required. The update mask applies to the resource. See
+            `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__.
+    """
+
+    endpoint = proto.Field(proto.MESSAGE, number=1, message=gca_endpoint.Endpoint,)
+
+    update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,)
+
+
+class DeleteEndpointRequest(proto.Message):
+    r"""Request message for
+    ``EndpointService.DeleteEndpoint``.
+
+    Attributes:
+        name (str):
+            Required. The name of the Endpoint resource to be deleted.
+            Format:
+            ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+    """
+
+    name = proto.Field(proto.STRING, number=1)
+
+
+class DeployModelRequest(proto.Message):
+    r"""Request message for
+    ``EndpointService.DeployModel``.
+
+    Attributes:
+        endpoint (str):
+            Required. The name of the Endpoint resource into which to
+            deploy a Model. Format:
+            ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+        deployed_model (google.cloud.aiplatform_v1.types.DeployedModel):
+            Required. The DeployedModel to be created within the
+            Endpoint. Note that
+            ``Endpoint.traffic_split``
+            must be updated for the DeployedModel to start receiving
+            traffic, either as part of this call, or via
+            ``EndpointService.UpdateEndpoint``.
+        traffic_split (Sequence[google.cloud.aiplatform_v1.types.DeployModelRequest.TrafficSplitEntry]):
+            A map from a DeployedModel's ID to the percentage of this
+            Endpoint's traffic that should be forwarded to that
+            DeployedModel.
+
+            If this field is non-empty, then the Endpoint's
+            ``traffic_split``
+            will be overwritten with it. To refer to the ID of the
+            Model being deployed by this request, use "0", and the
+            actual ID of the new DeployedModel will be filled in its
+            place by this method. The traffic percentage values must
+            add up to 100.
+
+            If this field is empty, then the Endpoint's
+            ``traffic_split``
+            is not updated.
+    """
+
+    endpoint = proto.Field(proto.STRING, number=1)
+
+    deployed_model = proto.Field(
+        proto.MESSAGE, number=2, message=gca_endpoint.DeployedModel,
+    )
+
+    traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3)
+
+
+class DeployModelResponse(proto.Message):
+    r"""Response message for
+    ``EndpointService.DeployModel``.
+
+    Attributes:
+        deployed_model (google.cloud.aiplatform_v1.types.DeployedModel):
+            The DeployedModel that had been deployed in
+            the Endpoint.
+    """
+
+    deployed_model = proto.Field(
+        proto.MESSAGE, number=1, message=gca_endpoint.DeployedModel,
+    )
+
+
+class DeployModelOperationMetadata(proto.Message):
+    r"""Runtime operation information for
+    ``EndpointService.DeployModel``.
+
+    Attributes:
+        generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata):
+            The operation generic information.
+    """
+
+    generic_metadata = proto.Field(
+        proto.MESSAGE, number=1, message=operation.GenericOperationMetadata,
+    )
+
+
+class UndeployModelRequest(proto.Message):
+    r"""Request message for
+    ``EndpointService.UndeployModel``.
+
+    Attributes:
+        endpoint (str):
+            Required. The name of the Endpoint resource from which to
+            undeploy a Model. Format:
+            ``projects/{project}/locations/{location}/endpoints/{endpoint}``
+        deployed_model_id (str):
+            Required. The ID of the DeployedModel to be
+            undeployed from the Endpoint.
+        traffic_split (Sequence[google.cloud.aiplatform_v1.types.UndeployModelRequest.TrafficSplitEntry]):
+            If this field is provided, then the Endpoint's
+            ``traffic_split``
+            will be overwritten with it. If the last DeployedModel is
+            being undeployed from the Endpoint, the
+            [Endpoint.traffic_split] will always end up empty when this
+            call returns. A DeployedModel will be successfully
+            undeployed only if it doesn't have any traffic assigned to
+            it when this method executes, or if this field unassigns
+            any traffic from it.
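+
+    Example (an illustrative sketch only; the endpoint name and
+    DeployedModel IDs are hypothetical placeholders)::
+
+        from google.cloud.aiplatform_v1.types import endpoint_service
+
+        request = endpoint_service.UndeployModelRequest(
+            endpoint="projects/my-project/locations/us-central1/endpoints/456",
+            deployed_model_id="model-to-remove",
+            # Reassign all traffic to the remaining DeployedModel so the
+            # undeployed model has no traffic when the call executes.
+            traffic_split={"remaining-model": 100},
+        )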
+ """ + + endpoint = proto.Field(proto.STRING, number=1) + + deployed_model_id = proto.Field(proto.STRING, number=2) + + traffic_split = proto.MapField(proto.STRING, proto.INT32, number=3) + + +class UndeployModelResponse(proto.Message): + r"""Response message for + ``EndpointService.UndeployModel``. + """ + + +class UndeployModelOperationMetadata(proto.Message): + r"""Runtime operation information for + ``EndpointService.UndeployModel``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata): + The operation generic information. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/env_var.py b/google/cloud/aiplatform_v1/types/env_var.py new file mode 100644 index 0000000000..f456c15808 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/env_var.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module(package="google.cloud.aiplatform.v1", manifest={"EnvVar",},) + + +class EnvVar(proto.Message): + r"""Represents an environment variable present in a Container or + Python Module. + + Attributes: + name (str): + Required. Name of the environment variable. + Must be a valid C identifier. + value (str): + Required. Variables that reference a $(VAR_NAME) are + expanded using the previous defined environment variables in + the container and any service environment variables. If a + variable cannot be resolved, the reference in the input + string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped + references will never be expanded, regardless of whether the + variable exists or not. + """ + + name = proto.Field(proto.STRING, number=1) + + value = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/hyperparameter_tuning_job.py b/google/cloud/aiplatform_v1/types/hyperparameter_tuning_job.py new file mode 100644 index 0000000000..63290ff9b4 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/hyperparameter_tuning_job.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import proto  # type: ignore
+
+
+from google.cloud.aiplatform_v1.types import custom_job
+from google.cloud.aiplatform_v1.types import encryption_spec as gca_encryption_spec
+from google.cloud.aiplatform_v1.types import job_state
+from google.cloud.aiplatform_v1.types import study
+from google.protobuf import timestamp_pb2 as timestamp  # type: ignore
+from google.rpc import status_pb2 as status  # type: ignore
+
+
+__protobuf__ = proto.module(
+    package="google.cloud.aiplatform.v1", manifest={"HyperparameterTuningJob",},
+)
+
+
+class HyperparameterTuningJob(proto.Message):
+    r"""Represents a HyperparameterTuningJob. A
+    HyperparameterTuningJob has a Study specification and multiple
+    CustomJobs with identical CustomJob specification.
+
+    Attributes:
+        name (str):
+            Output only. Resource name of the
+            HyperparameterTuningJob.
+        display_name (str):
+            Required. The display name of the
+            HyperparameterTuningJob. The name can be up to
+            128 characters long and can consist of any UTF-8
+            characters.
+        study_spec (google.cloud.aiplatform_v1.types.StudySpec):
+            Required. Study configuration of the
+            HyperparameterTuningJob.
+        max_trial_count (int):
+            Required. The desired total number of Trials.
+        parallel_trial_count (int):
+            Required. The desired number of Trials to run
+            in parallel.
+        max_failed_trial_count (int):
+            The number of failed Trials that need to be
+            seen before failing the HyperparameterTuningJob.
+            If set to 0, AI Platform decides how many Trials
+            must fail before the whole job fails.
+        trial_job_spec (google.cloud.aiplatform_v1.types.CustomJobSpec):
+            Required. The spec of a trial job. The same
+            spec applies to the CustomJobs created in all
+            the trials.
+        trials (Sequence[google.cloud.aiplatform_v1.types.Trial]):
+            Output only. Trials of the
+            HyperparameterTuningJob.
+        state (google.cloud.aiplatform_v1.types.JobState):
+            Output only. The detailed state of the job.
+        create_time (google.protobuf.timestamp_pb2.Timestamp):
+            Output only. Time when the
+            HyperparameterTuningJob was created.
+        start_time (google.protobuf.timestamp_pb2.Timestamp):
+            Output only. Time when the HyperparameterTuningJob for the
+            first time entered the ``JOB_STATE_RUNNING`` state.
+        end_time (google.protobuf.timestamp_pb2.Timestamp):
+            Output only. Time when the HyperparameterTuningJob entered
+            any of the following states: ``JOB_STATE_SUCCEEDED``,
+            ``JOB_STATE_FAILED``, ``JOB_STATE_CANCELLED``.
+        update_time (google.protobuf.timestamp_pb2.Timestamp):
+            Output only. Time when the
+            HyperparameterTuningJob was most recently
+            updated.
+        error (google.rpc.status_pb2.Status):
+            Output only. Only populated when the job's state is
+            JOB_STATE_FAILED or JOB_STATE_CANCELLED.
+        labels (Sequence[google.cloud.aiplatform_v1.types.HyperparameterTuningJob.LabelsEntry]):
+            The labels with user-defined metadata to
+            organize HyperparameterTuningJobs.
+            Label keys and values can be no longer than 64
+            characters (Unicode codepoints), can only
+            contain lowercase letters, numeric characters,
+            underscores and dashes. International characters
+            are allowed.
+            See https://goo.gl/xmQnxf for more information
+            and examples of labels.
+        encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec):
+            Customer-managed encryption key options for a
+            HyperparameterTuningJob. If this is set, then
+            all resources created by the
+            HyperparameterTuningJob will be encrypted with
+            the provided encryption key.
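+
+    Example (an illustrative sketch only; the display name and trial
+    counts are hypothetical, and the required study_spec and
+    trial_job_spec are elided for brevity)::
+
+        from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job
+
+        job = hyperparameter_tuning_job.HyperparameterTuningJob(
+            display_name="my-hp-tuning-job",
+            max_trial_count=20,
+            parallel_trial_count=5,
+            max_failed_trial_count=3,
+        )
+        # study_spec and trial_job_spec (required by the service) would
+        # be populated before submitting the job via
+        # JobService.CreateHyperparameterTuningJob.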
+ """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + study_spec = proto.Field(proto.MESSAGE, number=4, message=study.StudySpec,) + + max_trial_count = proto.Field(proto.INT32, number=5) + + parallel_trial_count = proto.Field(proto.INT32, number=6) + + max_failed_trial_count = proto.Field(proto.INT32, number=7) + + trial_job_spec = proto.Field( + proto.MESSAGE, number=8, message=custom_job.CustomJobSpec, + ) + + trials = proto.RepeatedField(proto.MESSAGE, number=9, message=study.Trial,) + + state = proto.Field(proto.ENUM, number=10, enum=job_state.JobState,) + + create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) + + start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) + + end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) + + error = proto.Field(proto.MESSAGE, number=15, message=status.Status,) + + labels = proto.MapField(proto.STRING, proto.STRING, number=16) + + encryption_spec = proto.Field( + proto.MESSAGE, number=17, message=gca_encryption_spec.EncryptionSpec, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/io.py b/google/cloud/aiplatform_v1/types/io.py new file mode 100644 index 0000000000..1a75ea33bc --- /dev/null +++ b/google/cloud/aiplatform_v1/types/io.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "GcsSource", + "GcsDestination", + "BigQuerySource", + "BigQueryDestination", + "ContainerRegistryDestination", + }, +) + + +class GcsSource(proto.Message): + r"""The Google Cloud Storage location for the input content. + + Attributes: + uris (Sequence[str]): + Required. Google Cloud Storage URI(-s) to the + input file(s). May contain wildcards. For more + information on wildcards, see + https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + """ + + uris = proto.RepeatedField(proto.STRING, number=1) + + +class GcsDestination(proto.Message): + r"""The Google Cloud Storage location where the output is to be + written to. + + Attributes: + output_uri_prefix (str): + Required. Google Cloud Storage URI to output + directory. If the uri doesn't end with '/', a + '/' will be automatically appended. The + directory is created if it doesn't exist. + """ + + output_uri_prefix = proto.Field(proto.STRING, number=1) + + +class BigQuerySource(proto.Message): + r"""The BigQuery location for the input content. + + Attributes: + input_uri (str): + Required. BigQuery URI to a table, up to 2000 characters + long. Accepted forms: + + - BigQuery path. For example: + ``bq://projectId.bqDatasetId.bqTableId``. 
+ """ + + input_uri = proto.Field(proto.STRING, number=1) + + +class BigQueryDestination(proto.Message): + r"""The BigQuery location for the output content. + + Attributes: + output_uri (str): + Required. BigQuery URI to a project or table, up to 2000 + characters long. + + When only the project is specified, the Dataset and Table + are created. When the full table reference is specified, the + Dataset must exist and table must not exist. + + Accepted forms: + + - BigQuery path. For example: ``bq://projectId`` or + ``bq://projectId.bqDatasetId.bqTableId``. + """ + + output_uri = proto.Field(proto.STRING, number=1) + + +class ContainerRegistryDestination(proto.Message): + r"""The Container Registry location for the container image. + + Attributes: + output_uri (str): + Required. Container Registry URI of a container image. Only + Google Container Registry and Artifact Registry are + supported now. Accepted forms: + + - Google Container Registry path. For example: + ``gcr.io/projectId/imageName:tag``. + + - Artifact Registry path. For example: + ``us-central1-docker.pkg.dev/projectId/repoName/imageName:tag``. + + If a tag is not specified, "latest" will be used as the + default tag. + """ + + output_uri = proto.Field(proto.STRING, number=1) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/job_service.py b/google/cloud/aiplatform_v1/types/job_service.py new file mode 100644 index 0000000000..3a6d844ea7 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/job_service.py @@ -0,0 +1,619 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) +from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job +from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "CreateCustomJobRequest", + "GetCustomJobRequest", + "ListCustomJobsRequest", + "ListCustomJobsResponse", + "DeleteCustomJobRequest", + "CancelCustomJobRequest", + "CreateDataLabelingJobRequest", + "GetDataLabelingJobRequest", + "ListDataLabelingJobsRequest", + "ListDataLabelingJobsResponse", + "DeleteDataLabelingJobRequest", + "CancelDataLabelingJobRequest", + "CreateHyperparameterTuningJobRequest", + "GetHyperparameterTuningJobRequest", + "ListHyperparameterTuningJobsRequest", + "ListHyperparameterTuningJobsResponse", + "DeleteHyperparameterTuningJobRequest", + "CancelHyperparameterTuningJobRequest", + "CreateBatchPredictionJobRequest", + "GetBatchPredictionJobRequest", + "ListBatchPredictionJobsRequest", + "ListBatchPredictionJobsResponse", + "DeleteBatchPredictionJobRequest", + "CancelBatchPredictionJobRequest", + }, +) + + +class CreateCustomJobRequest(proto.Message): + r"""Request message for + ``JobService.CreateCustomJob``. + + Attributes: + parent (str): + Required. The resource name of the Location to create the + CustomJob in. Format: + ``projects/{project}/locations/{location}`` + custom_job (google.cloud.aiplatform_v1.types.CustomJob): + Required. The CustomJob to create. + """ + + parent = proto.Field(proto.STRING, number=1) + + custom_job = proto.Field(proto.MESSAGE, number=2, message=gca_custom_job.CustomJob,) + + +class GetCustomJobRequest(proto.Message): + r"""Request message for + ``JobService.GetCustomJob``. + + Attributes: + name (str): + Required. The name of the CustomJob resource. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListCustomJobsRequest(proto.Message): + r"""Request message for + ``JobService.ListCustomJobs``. + + Attributes: + parent (str): + Required. The resource name of the Location to list the + CustomJobs from. Format: + ``projects/{project}/locations/{location}`` + filter (str): + The standard list filter. + + Supported fields: + + - ``display_name`` supports = and !=. + + - ``state`` supports = and !=. + + Some examples of using the filter are: + + - ``state="JOB_STATE_SUCCEEDED" AND display_name="my_job"`` + + - ``state="JOB_STATE_RUNNING" OR display_name="my_job"`` + + - ``NOT display_name="my_job"`` + + - ``state="JOB_STATE_FAILED"`` + page_size (int): + The standard list page size. + page_token (str): + The standard list page token. Typically obtained via + ``ListCustomJobsResponse.next_page_token`` + of the previous + ``JobService.ListCustomJobs`` + call. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. 
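+
+    Example (an illustrative sketch only; the project, location, and
+    filter values are hypothetical placeholders)::
+
+        from google.cloud.aiplatform_v1.types import job_service
+
+        request = job_service.ListCustomJobsRequest(
+            parent="projects/my-project/locations/us-central1",
+            # Filter syntax as documented above.
+            filter='state="JOB_STATE_RUNNING" OR display_name="my_job"',
+            page_size=100,
+        )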
+ """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + +class ListCustomJobsResponse(proto.Message): + r"""Response message for + ``JobService.ListCustomJobs`` + + Attributes: + custom_jobs (Sequence[google.cloud.aiplatform_v1.types.CustomJob]): + List of CustomJobs in the requested page. + next_page_token (str): + A token to retrieve the next page of results. Pass to + ``ListCustomJobsRequest.page_token`` + to obtain that page. + """ + + @property + def raw_page(self): + return self + + custom_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_custom_job.CustomJob, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class DeleteCustomJobRequest(proto.Message): + r"""Request message for + ``JobService.DeleteCustomJob``. + + Attributes: + name (str): + Required. The name of the CustomJob resource to be deleted. + Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CancelCustomJobRequest(proto.Message): + r"""Request message for + ``JobService.CancelCustomJob``. + + Attributes: + name (str): + Required. The name of the CustomJob to cancel. Format: + ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CreateDataLabelingJobRequest(proto.Message): + r"""Request message for + [DataLabelingJobService.CreateDataLabelingJob][]. + + Attributes: + parent (str): + Required. The parent of the DataLabelingJob. Format: + ``projects/{project}/locations/{location}`` + data_labeling_job (google.cloud.aiplatform_v1.types.DataLabelingJob): + Required. The DataLabelingJob to create. + """ + + parent = proto.Field(proto.STRING, number=1) + + data_labeling_job = proto.Field( + proto.MESSAGE, number=2, message=gca_data_labeling_job.DataLabelingJob, + ) + + +class GetDataLabelingJobRequest(proto.Message): + r"""Request message for [DataLabelingJobService.GetDataLabelingJob][]. + + Attributes: + name (str): + Required. The name of the DataLabelingJob. Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListDataLabelingJobsRequest(proto.Message): + r"""Request message for [DataLabelingJobService.ListDataLabelingJobs][]. + + Attributes: + parent (str): + Required. The parent of the DataLabelingJob. Format: + ``projects/{project}/locations/{location}`` + filter (str): + The standard list filter. + + Supported fields: + + - ``display_name`` supports = and !=. + + - ``state`` supports = and !=. + + Some examples of using the filter are: + + - ``state="JOB_STATE_SUCCEEDED" AND display_name="my_job"`` + + - ``state="JOB_STATE_RUNNING" OR display_name="my_job"`` + + - ``NOT display_name="my_job"`` + + - ``state="JOB_STATE_FAILED"`` + page_size (int): + The standard list page size. + page_token (str): + The standard list page token. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. FieldMask represents a + set of symbolic field paths. For example, the mask can be + ``paths: "name"``. The "name" here is a field in + DataLabelingJob. If this field is not set, all fields of the + DataLabelingJob are returned. 
+ order_by (str): + A comma-separated list of fields to order by, sorted in + ascending order by default. Use ``desc`` after a field name + for descending. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + order_by = proto.Field(proto.STRING, number=6) + + +class ListDataLabelingJobsResponse(proto.Message): + r"""Response message for + ``JobService.ListDataLabelingJobs``. + + Attributes: + data_labeling_jobs (Sequence[google.cloud.aiplatform_v1.types.DataLabelingJob]): + A list of DataLabelingJobs that matches the + specified filter in the request. + next_page_token (str): + The standard List next-page token. + """ + + @property + def raw_page(self): + return self + + data_labeling_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_data_labeling_job.DataLabelingJob, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class DeleteDataLabelingJobRequest(proto.Message): + r"""Request message for + ``JobService.DeleteDataLabelingJob``. + + Attributes: + name (str): + Required. The name of the DataLabelingJob to be deleted. + Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CancelDataLabelingJobRequest(proto.Message): + r"""Request message for + [DataLabelingJobService.CancelDataLabelingJob][]. + + Attributes: + name (str): + Required. The name of the DataLabelingJob. Format: + + ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CreateHyperparameterTuningJobRequest(proto.Message): + r"""Request message for + ``JobService.CreateHyperparameterTuningJob``. + + Attributes: + parent (str): + Required. The resource name of the Location to create the + HyperparameterTuningJob in. Format: + ``projects/{project}/locations/{location}`` + hyperparameter_tuning_job (google.cloud.aiplatform_v1.types.HyperparameterTuningJob): + Required. The HyperparameterTuningJob to + create. + """ + + parent = proto.Field(proto.STRING, number=1) + + hyperparameter_tuning_job = proto.Field( + proto.MESSAGE, + number=2, + message=gca_hyperparameter_tuning_job.HyperparameterTuningJob, + ) + + +class GetHyperparameterTuningJobRequest(proto.Message): + r"""Request message for + ``JobService.GetHyperparameterTuningJob``. + + Attributes: + name (str): + Required. The name of the HyperparameterTuningJob resource. + Format: + + ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListHyperparameterTuningJobsRequest(proto.Message): + r"""Request message for + ``JobService.ListHyperparameterTuningJobs``. + + Attributes: + parent (str): + Required. The resource name of the Location to list the + HyperparameterTuningJobs from. Format: + ``projects/{project}/locations/{location}`` + filter (str): + The standard list filter. + + Supported fields: + + - ``display_name`` supports = and !=. + + - ``state`` supports = and !=. 
+
+            Some examples of using the filter are:
+
+            -  ``state="JOB_STATE_SUCCEEDED" AND display_name="my_job"``
+
+            -  ``state="JOB_STATE_RUNNING" OR display_name="my_job"``
+
+            -  ``NOT display_name="my_job"``
+
+            -  ``state="JOB_STATE_FAILED"``
+        page_size (int):
+            The standard list page size.
+        page_token (str):
+            The standard list page token. Typically obtained via
+            ``ListHyperparameterTuningJobsResponse.next_page_token``
+            of the previous
+            ``JobService.ListHyperparameterTuningJobs``
+            call.
+        read_mask (google.protobuf.field_mask_pb2.FieldMask):
+            Mask specifying which fields to read.
+    """
+
+    parent = proto.Field(proto.STRING, number=1)
+
+    filter = proto.Field(proto.STRING, number=2)
+
+    page_size = proto.Field(proto.INT32, number=3)
+
+    page_token = proto.Field(proto.STRING, number=4)
+
+    read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,)
+
+
+class ListHyperparameterTuningJobsResponse(proto.Message):
+    r"""Response message for
+    ``JobService.ListHyperparameterTuningJobs``.
+
+    Attributes:
+        hyperparameter_tuning_jobs (Sequence[google.cloud.aiplatform_v1.types.HyperparameterTuningJob]):
+            List of HyperparameterTuningJobs in the requested page.
+            ``HyperparameterTuningJob.trials``
+            of the jobs will not be returned.
+        next_page_token (str):
+            A token to retrieve the next page of results. Pass to
+            ``ListHyperparameterTuningJobsRequest.page_token``
+            to obtain that page.
+    """
+
+    @property
+    def raw_page(self):
+        return self
+
+    hyperparameter_tuning_jobs = proto.RepeatedField(
+        proto.MESSAGE,
+        number=1,
+        message=gca_hyperparameter_tuning_job.HyperparameterTuningJob,
+    )
+
+    next_page_token = proto.Field(proto.STRING, number=2)
+
+
+class DeleteHyperparameterTuningJobRequest(proto.Message):
+    r"""Request message for
+    ``JobService.DeleteHyperparameterTuningJob``.
+
+    Attributes:
+        name (str):
+            Required. The name of the HyperparameterTuningJob resource
+            to be deleted. Format:
+
+            ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}``
+    """
+
+    name = proto.Field(proto.STRING, number=1)
+
+
+class CancelHyperparameterTuningJobRequest(proto.Message):
+    r"""Request message for
+    ``JobService.CancelHyperparameterTuningJob``.
+
+    Attributes:
+        name (str):
+            Required. The name of the HyperparameterTuningJob to cancel.
+            Format:
+
+            ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}``
+    """
+
+    name = proto.Field(proto.STRING, number=1)
+
+
+class CreateBatchPredictionJobRequest(proto.Message):
+    r"""Request message for
+    ``JobService.CreateBatchPredictionJob``.
+
+    Attributes:
+        parent (str):
+            Required. The resource name of the Location to create the
+            BatchPredictionJob in. Format:
+            ``projects/{project}/locations/{location}``
+        batch_prediction_job (google.cloud.aiplatform_v1.types.BatchPredictionJob):
+            Required. The BatchPredictionJob to create.
+    """
+
+    parent = proto.Field(proto.STRING, number=1)
+
+    batch_prediction_job = proto.Field(
+        proto.MESSAGE, number=2, message=gca_batch_prediction_job.BatchPredictionJob,
+    )
+
+
+class GetBatchPredictionJobRequest(proto.Message):
+    r"""Request message for
+    ``JobService.GetBatchPredictionJob``.
+
+    Attributes:
+        name (str):
+            Required. The name of the BatchPredictionJob resource.
+ Format: + + ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListBatchPredictionJobsRequest(proto.Message): + r"""Request message for + ``JobService.ListBatchPredictionJobs``. + + Attributes: + parent (str): + Required. The resource name of the Location to list the + BatchPredictionJobs from. Format: + ``projects/{project}/locations/{location}`` + filter (str): + The standard list filter. + + Supported fields: + + - ``display_name`` supports = and !=. + + - ``state`` supports = and !=. + + Some examples of using the filter are: + + - ``state="JOB_STATE_SUCCEEDED" AND display_name="my_job"`` + + - ``state="JOB_STATE_RUNNING" OR display_name="my_job"`` + + - ``NOT display_name="my_job"`` + + - ``state="JOB_STATE_FAILED"`` + page_size (int): + The standard list page size. + page_token (str): + The standard list page token. Typically obtained via + ``ListBatchPredictionJobsResponse.next_page_token`` + of the previous + ``JobService.ListBatchPredictionJobs`` + call. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + +class ListBatchPredictionJobsResponse(proto.Message): + r"""Response message for + ``JobService.ListBatchPredictionJobs`` + + Attributes: + batch_prediction_jobs (Sequence[google.cloud.aiplatform_v1.types.BatchPredictionJob]): + List of BatchPredictionJobs in the requested + page. + next_page_token (str): + A token to retrieve the next page of results. Pass to + ``ListBatchPredictionJobsRequest.page_token`` + to obtain that page. + """ + + @property + def raw_page(self): + return self + + batch_prediction_jobs = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_batch_prediction_job.BatchPredictionJob, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class DeleteBatchPredictionJobRequest(proto.Message): + r"""Request message for + ``JobService.DeleteBatchPredictionJob``. + + Attributes: + name (str): + Required. The name of the BatchPredictionJob resource to be + deleted. Format: + + ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CancelBatchPredictionJobRequest(proto.Message): + r"""Request message for + ``JobService.CancelBatchPredictionJob``. + + Attributes: + name (str): + Required. The name of the BatchPredictionJob to cancel. + Format: + + ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/job_state.py b/google/cloud/aiplatform_v1/types/job_state.py new file mode 100644 index 0000000000..40b1694f86 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/job_state.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"JobState",}, +) + + +class JobState(proto.Enum): + r"""Describes the state of a job.""" + JOB_STATE_UNSPECIFIED = 0 + JOB_STATE_QUEUED = 1 + JOB_STATE_PENDING = 2 + JOB_STATE_RUNNING = 3 + JOB_STATE_SUCCEEDED = 4 + JOB_STATE_FAILED = 5 + JOB_STATE_CANCELLING = 6 + JOB_STATE_CANCELLED = 7 + JOB_STATE_PAUSED = 8 + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/machine_resources.py b/google/cloud/aiplatform_v1/types/machine_resources.py new file mode 100644 index 0000000000..f6864eb798 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/machine_resources.py @@ -0,0 +1,212 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import accelerator_type as gca_accelerator_type + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "MachineSpec", + "DedicatedResources", + "AutomaticResources", + "BatchDedicatedResources", + "ResourcesConsumed", + "DiskSpec", + }, +) + + +class MachineSpec(proto.Message): + r"""Specification of a single machine. + + Attributes: + machine_type (str): + Immutable. The type of the machine. For the machine types + supported for prediction, see + https://tinyurl.com/aip-docs/predictions/machine-types. For + machine types supported for creating a custom training job, + see https://tinyurl.com/aip-docs/training/configure-compute. + + For + ``DeployedModel`` + this field is optional, and the default value is + ``n1-standard-2``. For + ``BatchPredictionJob`` + or as part of + ``WorkerPoolSpec`` + this field is required. + accelerator_type (google.cloud.aiplatform_v1.types.AcceleratorType): + Immutable. The type of accelerator(s) that may be attached + to the machine as per + ``accelerator_count``. + accelerator_count (int): + The number of accelerators to attach to the + machine. + """ + + machine_type = proto.Field(proto.STRING, number=1) + + accelerator_type = proto.Field( + proto.ENUM, number=2, enum=gca_accelerator_type.AcceleratorType, + ) + + accelerator_count = proto.Field(proto.INT32, number=3) + + +class DedicatedResources(proto.Message): + r"""A description of resources that are dedicated to a + DeployedModel, and that need a higher degree of manual + configuration. + + Attributes: + machine_spec (google.cloud.aiplatform_v1.types.MachineSpec): + Required. Immutable. The specification of a + single machine used by the prediction. 
+        min_replica_count (int):
+            Required. Immutable. The minimum number of machine replicas
+            this DeployedModel will always be deployed on. If traffic
+            against it increases, it may dynamically be deployed onto
+            more replicas, and as traffic decreases, some of these extra
+            replicas may be freed. Note: if
+            ``machine_spec.accelerator_count``
+            is above 0, currently the model will always be deployed
+            precisely on
+            ``min_replica_count``.
+        max_replica_count (int):
+            Immutable. The maximum number of replicas this DeployedModel
+            may be deployed on when the traffic against it increases. If
+            the requested value is too large, the deployment will error,
+            but if deployment succeeds then the ability to scale the
+            model to that many replicas is guaranteed (barring service
+            outages). If traffic against the DeployedModel increases
+            beyond what its replicas at maximum may handle, a portion of
+            the traffic will be dropped. If this value is not provided,
+            ``min_replica_count``
+            will be used as the default value.
+    """
+
+    machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",)
+
+    min_replica_count = proto.Field(proto.INT32, number=2)
+
+    max_replica_count = proto.Field(proto.INT32, number=3)
+
+
+class AutomaticResources(proto.Message):
+    r"""A description of resources that are, to a large degree,
+    decided by AI Platform, and that require only a modest
+    additional configuration. Each Model supporting these resources
+    documents its specific guidelines.
+
+    Attributes:
+        min_replica_count (int):
+            Immutable. The minimum number of replicas this DeployedModel
+            will always be deployed on. If traffic against it increases,
+            it may dynamically be deployed onto more replicas up to
+            ``max_replica_count``,
+            and as traffic decreases, some of these extra replicas may
+            be freed. If the requested value is too large, the
+            deployment will error.
+        max_replica_count (int):
+            Immutable. The maximum number of replicas
+            this DeployedModel may be deployed on when the
+            traffic against it increases. If the requested
+            value is too large, the deployment will error,
+            but if deployment succeeds then the ability to
+            scale the model to that many replicas is
+            guaranteed (barring service outages). If traffic
+            against the DeployedModel increases beyond what
+            its replicas at maximum may handle, a portion of
+            the traffic will be dropped. If this value is
+            not provided, no upper bound for scaling under
+            heavy traffic is assumed, though AI Platform
+            may be unable to scale beyond a certain replica
+            count.
+    """
+
+    min_replica_count = proto.Field(proto.INT32, number=1)
+
+    max_replica_count = proto.Field(proto.INT32, number=2)
+
+
+class BatchDedicatedResources(proto.Message):
+    r"""A description of resources that are used for performing batch
+    operations, are dedicated to a Model, and need manual
+    configuration.
+
+    Attributes:
+        machine_spec (google.cloud.aiplatform_v1.types.MachineSpec):
+            Required. Immutable. The specification of a
+            single machine.
+        starting_replica_count (int):
+            Immutable. The number of machine replicas used at the start
+            of the batch operation. If not set, AI Platform decides the
+            starting number, not greater than
+            ``max_replica_count``.
+        max_replica_count (int):
+            Immutable. The maximum number of machine
+            replicas the batch operation may be scaled to.
+            The default value is 10.
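+
+    A minimal construction sketch (the machine type and replica
+    counts below are illustrative placeholders, not recommendations):
+
+    .. code:: python
+
+        from google.cloud import aiplatform_v1
+
+        # Dedicated batch resources with an explicit machine type.
+        resources = aiplatform_v1.BatchDedicatedResources(
+            machine_spec=aiplatform_v1.MachineSpec(
+                machine_type="n1-standard-4",  # placeholder machine type
+            ),
+            starting_replica_count=1,
+            max_replica_count=10,
+        )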
+ """ + + machine_spec = proto.Field(proto.MESSAGE, number=1, message="MachineSpec",) + + starting_replica_count = proto.Field(proto.INT32, number=2) + + max_replica_count = proto.Field(proto.INT32, number=3) + + +class ResourcesConsumed(proto.Message): + r"""Statistics information about resource consumption. + + Attributes: + replica_hours (float): + Output only. The number of replica hours + used. Note that many replicas may run in + parallel, and additionally any given work may be + queued for some time. Therefore this value is + not strictly related to wall time. + """ + + replica_hours = proto.Field(proto.DOUBLE, number=1) + + +class DiskSpec(proto.Message): + r"""Represents the spec of disk options. + + Attributes: + boot_disk_type (str): + Type of the boot disk (default is "pd-ssd"). + Valid values: "pd-ssd" (Persistent Disk Solid + State Drive) or "pd-standard" (Persistent Disk + Hard Disk Drive). + boot_disk_size_gb (int): + Size in GB of the boot disk (default is + 100GB). + """ + + boot_disk_type = proto.Field(proto.STRING, number=1) + + boot_disk_size_gb = proto.Field(proto.INT32, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/manual_batch_tuning_parameters.py b/google/cloud/aiplatform_v1/types/manual_batch_tuning_parameters.py new file mode 100644 index 0000000000..7500d618a0 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/manual_batch_tuning_parameters.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"ManualBatchTuningParameters",}, +) + + +class ManualBatchTuningParameters(proto.Message): + r"""Manual batch tuning parameters. + + Attributes: + batch_size (int): + Immutable. The number of the records (e.g. + instances) of the operation given in each batch + to a machine replica. Machine type, and size of + a single record should be considered when + setting this parameter, higher value speeds up + the batch operation's execution, but too high + value will result in a whole batch not fitting + in a machine's memory, and the whole operation + will fail. + The default value is 4. + """ + + batch_size = proto.Field(proto.INT32, number=1) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/migratable_resource.py b/google/cloud/aiplatform_v1/types/migratable_resource.py new file mode 100644 index 0000000000..652a835c89 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/migratable_resource.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"MigratableResource",}, +) + + +class MigratableResource(proto.Message): + r"""Represents one resource that exists in automl.googleapis.com, + datalabeling.googleapis.com or ml.googleapis.com. + + Attributes: + ml_engine_model_version (google.cloud.aiplatform_v1.types.MigratableResource.MlEngineModelVersion): + Output only. Represents one Version in + ml.googleapis.com. + automl_model (google.cloud.aiplatform_v1.types.MigratableResource.AutomlModel): + Output only. Represents one Model in + automl.googleapis.com. + automl_dataset (google.cloud.aiplatform_v1.types.MigratableResource.AutomlDataset): + Output only. Represents one Dataset in + automl.googleapis.com. + data_labeling_dataset (google.cloud.aiplatform_v1.types.MigratableResource.DataLabelingDataset): + Output only. Represents one Dataset in + datalabeling.googleapis.com. + last_migrate_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when the last + migration attempt on this MigratableResource + started. Will not be set if there's no migration + attempt on this MigratableResource. + last_update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + MigratableResource was last updated. + """ + + class MlEngineModelVersion(proto.Message): + r"""Represents one model Version in ml.googleapis.com. + + Attributes: + endpoint (str): + The ml.googleapis.com endpoint that this model Version + currently lives in. Example values: + + - ml.googleapis.com + - us-centrall-ml.googleapis.com + - europe-west4-ml.googleapis.com + - asia-east1-ml.googleapis.com + version (str): + Full resource name of ml engine model Version. Format: + ``projects/{project}/models/{model}/versions/{version}``. + """ + + endpoint = proto.Field(proto.STRING, number=1) + + version = proto.Field(proto.STRING, number=2) + + class AutomlModel(proto.Message): + r"""Represents one Model in automl.googleapis.com. + + Attributes: + model (str): + Full resource name of automl Model. Format: + ``projects/{project}/locations/{location}/models/{model}``. + model_display_name (str): + The Model's display name in + automl.googleapis.com. + """ + + model = proto.Field(proto.STRING, number=1) + + model_display_name = proto.Field(proto.STRING, number=3) + + class AutomlDataset(proto.Message): + r"""Represents one Dataset in automl.googleapis.com. + + Attributes: + dataset (str): + Full resource name of automl Dataset. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}``. + dataset_display_name (str): + The Dataset's display name in + automl.googleapis.com. + """ + + dataset = proto.Field(proto.STRING, number=1) + + dataset_display_name = proto.Field(proto.STRING, number=4) + + class DataLabelingDataset(proto.Message): + r"""Represents one Dataset in datalabeling.googleapis.com. + + Attributes: + dataset (str): + Full resource name of data labeling Dataset. Format: + ``projects/{project}/datasets/{dataset}``. 
+ dataset_display_name (str): + The Dataset's display name in + datalabeling.googleapis.com. + data_labeling_annotated_datasets (Sequence[google.cloud.aiplatform_v1.types.MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset]): + The migratable AnnotatedDataset in + datalabeling.googleapis.com belongs to the data + labeling Dataset. + """ + + class DataLabelingAnnotatedDataset(proto.Message): + r"""Represents one AnnotatedDataset in + datalabeling.googleapis.com. + + Attributes: + annotated_dataset (str): + Full resource name of data labeling AnnotatedDataset. + Format: + + ``projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}``. + annotated_dataset_display_name (str): + The AnnotatedDataset's display name in + datalabeling.googleapis.com. + """ + + annotated_dataset = proto.Field(proto.STRING, number=1) + + annotated_dataset_display_name = proto.Field(proto.STRING, number=3) + + dataset = proto.Field(proto.STRING, number=1) + + dataset_display_name = proto.Field(proto.STRING, number=4) + + data_labeling_annotated_datasets = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset", + ) + + ml_engine_model_version = proto.Field( + proto.MESSAGE, number=1, oneof="resource", message=MlEngineModelVersion, + ) + + automl_model = proto.Field( + proto.MESSAGE, number=2, oneof="resource", message=AutomlModel, + ) + + automl_dataset = proto.Field( + proto.MESSAGE, number=3, oneof="resource", message=AutomlDataset, + ) + + data_labeling_dataset = proto.Field( + proto.MESSAGE, number=4, oneof="resource", message=DataLabelingDataset, + ) + + last_migrate_time = proto.Field( + proto.MESSAGE, number=5, message=timestamp.Timestamp, + ) + + last_update_time = proto.Field( + proto.MESSAGE, number=6, message=timestamp.Timestamp, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/migration_service.py b/google/cloud/aiplatform_v1/types/migration_service.py new file mode 100644 index 0000000000..acd69b37b4 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/migration_service.py @@ -0,0 +1,376 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import ( + migratable_resource as gca_migratable_resource, +) +from google.cloud.aiplatform_v1.types import operation +from google.rpc import status_pb2 as status # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "SearchMigratableResourcesRequest", + "SearchMigratableResourcesResponse", + "BatchMigrateResourcesRequest", + "MigrateResourceRequest", + "BatchMigrateResourcesResponse", + "MigrateResourceResponse", + "BatchMigrateResourcesOperationMetadata", + }, +) + + +class SearchMigratableResourcesRequest(proto.Message): + r"""Request message for + ``MigrationService.SearchMigratableResources``. 
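+
+    A usage sketch, assuming the client's flattened ``parent``
+    parameter (the project and location are placeholders):
+
+    .. code:: python
+
+        from google.cloud import aiplatform_v1
+
+        client = aiplatform_v1.MigrationServiceClient()
+        # Iterate over the pager returned by this list-style RPC.
+        for resource in client.search_migratable_resources(
+            parent="projects/my-project/locations/us-central1"
+        ):
+            print(resource)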
+ + Attributes: + parent (str): + Required. The location that the migratable resources should + be searched from. It's the AI Platform location that the + resources can be migrated to, not the resources' original + location. Format: + ``projects/{project}/locations/{location}`` + page_size (int): + The standard page size. + The default and maximum value is 100. + page_token (str): + The standard page token. + filter (str): + Supported filters are: + + - Resource type: For a specific type of MigratableResource. + + - ``ml_engine_model_version:*`` + - ``automl_model:*``, + - ``automl_dataset:*`` + - ``data_labeling_dataset:*``. + + - Migrated or not: Filter migrated resource or not by + last_migrate_time. + + - ``last_migrate_time:*`` will filter migrated + resources. + - ``NOT last_migrate_time:*`` will filter not yet + migrated resources. + """ + + parent = proto.Field(proto.STRING, number=1) + + page_size = proto.Field(proto.INT32, number=2) + + page_token = proto.Field(proto.STRING, number=3) + + filter = proto.Field(proto.STRING, number=4) + + +class SearchMigratableResourcesResponse(proto.Message): + r"""Response message for + ``MigrationService.SearchMigratableResources``. + + Attributes: + migratable_resources (Sequence[google.cloud.aiplatform_v1.types.MigratableResource]): + All migratable resources that can be migrated + to the location specified in the request. + next_page_token (str): + The standard next-page token. The migratable_resources may + not fill page_size in SearchMigratableResourcesRequest even + when there are subsequent pages. + """ + + @property + def raw_page(self): + return self + + migratable_resources = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_migratable_resource.MigratableResource, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class BatchMigrateResourcesRequest(proto.Message): + r"""Request message for + ``MigrationService.BatchMigrateResources``. + + Attributes: + parent (str): + Required. The location of the migrated resource will live + in. Format: ``projects/{project}/locations/{location}`` + migrate_resource_requests (Sequence[google.cloud.aiplatform_v1.types.MigrateResourceRequest]): + Required. The request messages specifying the + resources to migrate. They must be in the same + location as the destination. Up to 50 resources + can be migrated in one batch. + """ + + parent = proto.Field(proto.STRING, number=1) + + migrate_resource_requests = proto.RepeatedField( + proto.MESSAGE, number=2, message="MigrateResourceRequest", + ) + + +class MigrateResourceRequest(proto.Message): + r"""Config of migrating one resource from automl.googleapis.com, + datalabeling.googleapis.com and ml.googleapis.com to AI + Platform. + + Attributes: + migrate_ml_engine_model_version_config (google.cloud.aiplatform_v1.types.MigrateResourceRequest.MigrateMlEngineModelVersionConfig): + Config for migrating Version in + ml.googleapis.com to AI Platform's Model. + migrate_automl_model_config (google.cloud.aiplatform_v1.types.MigrateResourceRequest.MigrateAutomlModelConfig): + Config for migrating Model in + automl.googleapis.com to AI Platform's Model. + migrate_automl_dataset_config (google.cloud.aiplatform_v1.types.MigrateResourceRequest.MigrateAutomlDatasetConfig): + Config for migrating Dataset in + automl.googleapis.com to AI Platform's Dataset. 
+ migrate_data_labeling_dataset_config (google.cloud.aiplatform_v1.types.MigrateResourceRequest.MigrateDataLabelingDatasetConfig): + Config for migrating Dataset in + datalabeling.googleapis.com to AI Platform's + Dataset. + """ + + class MigrateMlEngineModelVersionConfig(proto.Message): + r"""Config for migrating version in ml.googleapis.com to AI + Platform's Model. + + Attributes: + endpoint (str): + Required. The ml.googleapis.com endpoint that this model + version should be migrated from. Example values: + + - ml.googleapis.com + + - us-centrall-ml.googleapis.com + + - europe-west4-ml.googleapis.com + + - asia-east1-ml.googleapis.com + model_version (str): + Required. Full resource name of ml engine model version. + Format: + ``projects/{project}/models/{model}/versions/{version}``. + model_display_name (str): + Required. Display name of the model in AI + Platform. System will pick a display name if + unspecified. + """ + + endpoint = proto.Field(proto.STRING, number=1) + + model_version = proto.Field(proto.STRING, number=2) + + model_display_name = proto.Field(proto.STRING, number=3) + + class MigrateAutomlModelConfig(proto.Message): + r"""Config for migrating Model in automl.googleapis.com to AI + Platform's Model. + + Attributes: + model (str): + Required. Full resource name of automl Model. Format: + ``projects/{project}/locations/{location}/models/{model}``. + model_display_name (str): + Optional. Display name of the model in AI + Platform. System will pick a display name if + unspecified. + """ + + model = proto.Field(proto.STRING, number=1) + + model_display_name = proto.Field(proto.STRING, number=2) + + class MigrateAutomlDatasetConfig(proto.Message): + r"""Config for migrating Dataset in automl.googleapis.com to AI + Platform's Dataset. + + Attributes: + dataset (str): + Required. Full resource name of automl Dataset. Format: + ``projects/{project}/locations/{location}/datasets/{dataset}``. + dataset_display_name (str): + Required. Display name of the Dataset in AI + Platform. System will pick a display name if + unspecified. + """ + + dataset = proto.Field(proto.STRING, number=1) + + dataset_display_name = proto.Field(proto.STRING, number=2) + + class MigrateDataLabelingDatasetConfig(proto.Message): + r"""Config for migrating Dataset in datalabeling.googleapis.com + to AI Platform's Dataset. + + Attributes: + dataset (str): + Required. Full resource name of data labeling Dataset. + Format: ``projects/{project}/datasets/{dataset}``. + dataset_display_name (str): + Optional. Display name of the Dataset in AI + Platform. System will pick a display name if + unspecified. + migrate_data_labeling_annotated_dataset_configs (Sequence[google.cloud.aiplatform_v1.types.MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig]): + Optional. Configs for migrating + AnnotatedDataset in datalabeling.googleapis.com + to AI Platform's SavedQuery. The specified + AnnotatedDatasets have to belong to the + datalabeling Dataset. + """ + + class MigrateDataLabelingAnnotatedDatasetConfig(proto.Message): + r"""Config for migrating AnnotatedDataset in + datalabeling.googleapis.com to AI Platform's SavedQuery. + + Attributes: + annotated_dataset (str): + Required. Full resource name of data labeling + AnnotatedDataset. Format: + + ``projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}``. 
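+
+            A construction sketch (the resource name below is a
+            placeholder):
+
+            .. code:: python
+
+                from google.cloud import aiplatform_v1
+
+                # Alias the deeply nested message type for readability.
+                Config = (
+                    aiplatform_v1.MigrateResourceRequest
+                    .MigrateDataLabelingDatasetConfig
+                    .MigrateDataLabelingAnnotatedDatasetConfig
+                )
+                config = Config(
+                    annotated_dataset="projects/123/datasets/456/annotatedDatasets/789",
+                )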
+ """ + + annotated_dataset = proto.Field(proto.STRING, number=1) + + dataset = proto.Field(proto.STRING, number=1) + + dataset_display_name = proto.Field(proto.STRING, number=2) + + migrate_data_labeling_annotated_dataset_configs = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig", + ) + + migrate_ml_engine_model_version_config = proto.Field( + proto.MESSAGE, + number=1, + oneof="request", + message=MigrateMlEngineModelVersionConfig, + ) + + migrate_automl_model_config = proto.Field( + proto.MESSAGE, number=2, oneof="request", message=MigrateAutomlModelConfig, + ) + + migrate_automl_dataset_config = proto.Field( + proto.MESSAGE, number=3, oneof="request", message=MigrateAutomlDatasetConfig, + ) + + migrate_data_labeling_dataset_config = proto.Field( + proto.MESSAGE, + number=4, + oneof="request", + message=MigrateDataLabelingDatasetConfig, + ) + + +class BatchMigrateResourcesResponse(proto.Message): + r"""Response message for + ``MigrationService.BatchMigrateResources``. + + Attributes: + migrate_resource_responses (Sequence[google.cloud.aiplatform_v1.types.MigrateResourceResponse]): + Successfully migrated resources. + """ + + migrate_resource_responses = proto.RepeatedField( + proto.MESSAGE, number=1, message="MigrateResourceResponse", + ) + + +class MigrateResourceResponse(proto.Message): + r"""Describes a successfully migrated resource. + + Attributes: + dataset (str): + Migrated Dataset's resource name. + model (str): + Migrated Model's resource name. + migratable_resource (google.cloud.aiplatform_v1.types.MigratableResource): + Before migration, the identifier in + ml.googleapis.com, automl.googleapis.com or + datalabeling.googleapis.com. + """ + + dataset = proto.Field(proto.STRING, number=1, oneof="migrated_resource") + + model = proto.Field(proto.STRING, number=2, oneof="migrated_resource") + + migratable_resource = proto.Field( + proto.MESSAGE, number=3, message=gca_migratable_resource.MigratableResource, + ) + + +class BatchMigrateResourcesOperationMetadata(proto.Message): + r"""Runtime operation information for + ``MigrationService.BatchMigrateResources``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata): + The common part of the operation metadata. + partial_results (Sequence[google.cloud.aiplatform_v1.types.BatchMigrateResourcesOperationMetadata.PartialResult]): + Partial results that reflect the latest + migration operation progress. + """ + + class PartialResult(proto.Message): + r"""Represents a partial result in batch migration operation for one + ``MigrateResourceRequest``. + + Attributes: + error (google.rpc.status_pb2.Status): + The error result of the migration request in + case of failure. + model (str): + Migrated model resource name. + dataset (str): + Migrated dataset resource name. + request (google.cloud.aiplatform_v1.types.MigrateResourceRequest): + It's the same as the value in + [MigrateResourceRequest.migrate_resource_requests][]. 
+ """ + + error = proto.Field( + proto.MESSAGE, number=2, oneof="result", message=status.Status, + ) + + model = proto.Field(proto.STRING, number=3, oneof="result") + + dataset = proto.Field(proto.STRING, number=4, oneof="result") + + request = proto.Field( + proto.MESSAGE, number=1, message="MigrateResourceRequest", + ) + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + partial_results = proto.RepeatedField( + proto.MESSAGE, number=2, message=PartialResult, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/model.py b/google/cloud/aiplatform_v1/types/model.py new file mode 100644 index 0000000000..c2db797b98 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/model.py @@ -0,0 +1,630 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import deployed_model_ref +from google.cloud.aiplatform_v1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform_v1.types import env_var +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={"Model", "PredictSchemata", "ModelContainerSpec", "Port",}, +) + + +class Model(proto.Message): + r"""A trained machine learning Model. + + Attributes: + name (str): + The resource name of the Model. + display_name (str): + Required. The display name of the Model. + The name can be up to 128 characters long and + can be consist of any UTF-8 characters. + description (str): + The description of the Model. + predict_schemata (google.cloud.aiplatform_v1.types.PredictSchemata): + The schemata that describe formats of the Model's + predictions and explanations as given and returned via + ``PredictionService.Predict`` + and [PredictionService.Explain][]. + metadata_schema_uri (str): + Immutable. Points to a YAML file stored on Google Cloud + Storage describing additional information about the Model, + that is specific to it. Unset if the Model does not have any + additional information. The schema is defined as an OpenAPI + 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no additional metadata is needed, this field is + set to an empty string. Note: The URI given on output will + be immutable and probably different, including the URI + scheme, than the one given on input. The output URI will + point to a location where the user only has a read access. + metadata (google.protobuf.struct_pb2.Value): + Immutable. An additional information about the Model; the + schema of the metadata can be found in + ``metadata_schema``. + Unset if the Model does not have any additional information. + supported_export_formats (Sequence[google.cloud.aiplatform_v1.types.Model.ExportFormat]): + Output only. 
The formats in which this Model + may be exported. If empty, this Model is not + available for export. + training_pipeline (str): + Output only. The resource name of the + TrainingPipeline that uploaded this Model, if + any. + container_spec (google.cloud.aiplatform_v1.types.ModelContainerSpec): + Input only. The specification of the container that is to be + used when deploying this Model. The specification is + ingested upon + ``ModelService.UploadModel``, + and all binaries it contains are copied and stored + internally by AI Platform. Not present for AutoML Models. + artifact_uri (str): + Immutable. The path to the directory + containing the Model artifact and any of its + supporting files. Not present for AutoML Models. + supported_deployment_resources_types (Sequence[google.cloud.aiplatform_v1.types.Model.DeploymentResourcesType]): + Output only. When this Model is deployed, its prediction + resources are described by the ``prediction_resources`` + field of the + ``Endpoint.deployed_models`` + object. Because not all Models support all resource + configuration types, the configuration types this Model + supports are listed here. If no configuration types are + listed, the Model cannot be deployed to an + ``Endpoint`` and does not + support online predictions + (``PredictionService.Predict`` + or [PredictionService.Explain][]). Such a Model can serve + predictions by using a + ``BatchPredictionJob``, + if it has at least one entry each in + ``supported_input_storage_formats`` + and + ``supported_output_storage_formats``. + supported_input_storage_formats (Sequence[str]): + Output only. The formats this Model supports in + ``BatchPredictionJob.input_config``. + If + ``PredictSchemata.instance_schema_uri`` + exists, the instances should be given as per that schema. + + The possible formats are: + + - ``jsonl`` The JSON Lines format, where each instance is a + single line. Uses + ``GcsSource``. + + - ``csv`` The CSV format, where each instance is a single + comma-separated line. The first line in the file is the + header, containing comma-separated field names. Uses + ``GcsSource``. + + - ``tf-record`` The TFRecord format, where each instance is + a single record in tfrecord syntax. Uses + ``GcsSource``. + + - ``tf-record-gzip`` Similar to ``tf-record``, but the file + is gzipped. Uses + ``GcsSource``. + + - ``bigquery`` Each instance is a single row in BigQuery. + Uses + ``BigQuerySource``. + + - ``file-list`` Each line of the file is the location of an + instance to process, uses ``gcs_source`` field of the + ``InputConfig`` + object. + + If this Model doesn't support any of these formats it means + it cannot be used with a + ``BatchPredictionJob``. + However, if it has + ``supported_deployment_resources_types``, + it could serve online predictions by using + ``PredictionService.Predict`` + or [PredictionService.Explain][]. + supported_output_storage_formats (Sequence[str]): + Output only. The formats this Model supports in + ``BatchPredictionJob.output_config``. + If both + ``PredictSchemata.instance_schema_uri`` + and + ``PredictSchemata.prediction_schema_uri`` + exist, the predictions are returned together with their + instances. In other words, the prediction has the original + instance data first, followed by the actual prediction + content (as per the schema). + + The possible formats are: + + - ``jsonl`` The JSON Lines format, where each prediction is + a single line. Uses + ``GcsDestination``. 
+ + - ``csv`` The CSV format, where each prediction is a single + comma-separated line. The first line in the file is the + header, containing comma-separated field names. Uses + ``GcsDestination``. + + - ``bigquery`` Each prediction is a single row in a + BigQuery table, uses + ``BigQueryDestination`` + . + + If this Model doesn't support any of these formats it means + it cannot be used with a + ``BatchPredictionJob``. + However, if it has + ``supported_deployment_resources_types``, + it could serve online predictions by using + ``PredictionService.Predict`` + or [PredictionService.Explain][]. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Model was + uploaded into AI Platform. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this Model was + most recently updated. + deployed_models (Sequence[google.cloud.aiplatform_v1.types.DeployedModelRef]): + Output only. The pointers to DeployedModels + created from this Model. Note that Model could + have been deployed to Endpoints in different + Locations. + etag (str): + Used to perform consistent read-modify-write + updates. If not set, a blind "overwrite" update + happens. + labels (Sequence[google.cloud.aiplatform_v1.types.Model.LabelsEntry]): + The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): + Customer-managed encryption key spec for a + Model. If set, this Model and all sub-resources + of this Model will be secured by this key. + """ + + class DeploymentResourcesType(proto.Enum): + r"""Identifies a type of Model's prediction resources.""" + DEPLOYMENT_RESOURCES_TYPE_UNSPECIFIED = 0 + DEDICATED_RESOURCES = 1 + AUTOMATIC_RESOURCES = 2 + + class ExportFormat(proto.Message): + r"""Represents export format supported by the Model. + All formats export to Google Cloud Storage. + + Attributes: + id (str): + Output only. The ID of the export format. The possible + format IDs are: + + - ``tflite`` Used for Android mobile devices. + + - ``edgetpu-tflite`` Used for `Edge + TPU `__ devices. + + - ``tf-saved-model`` A tensorflow model in SavedModel + format. + + - ``tf-js`` A + `TensorFlow.js `__ model + that can be used in the browser and in Node.js using + JavaScript. + + - ``core-ml`` Used for iOS mobile devices. + + - ``custom-trained`` A Model that was uploaded or trained + by custom code. + exportable_contents (Sequence[google.cloud.aiplatform_v1.types.Model.ExportFormat.ExportableContent]): + Output only. The content of this Model that + may be exported. 
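+
+        For illustration, a sketch of inspecting a Model's export
+        formats (the resource name is a placeholder):
+
+        .. code:: python
+
+            from google.cloud import aiplatform_v1
+
+            client = aiplatform_v1.ModelServiceClient()
+            model = client.get_model(
+                name="projects/my-project/locations/us-central1/models/123"
+            )
+            for export_format in model.supported_export_formats:
+                print(export_format.id, export_format.exportable_contents)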
+ """ + + class ExportableContent(proto.Enum): + r"""The Model content that can be exported.""" + EXPORTABLE_CONTENT_UNSPECIFIED = 0 + ARTIFACT = 1 + IMAGE = 2 + + id = proto.Field(proto.STRING, number=1) + + exportable_contents = proto.RepeatedField( + proto.ENUM, number=2, enum="Model.ExportFormat.ExportableContent", + ) + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + description = proto.Field(proto.STRING, number=3) + + predict_schemata = proto.Field(proto.MESSAGE, number=4, message="PredictSchemata",) + + metadata_schema_uri = proto.Field(proto.STRING, number=5) + + metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) + + supported_export_formats = proto.RepeatedField( + proto.MESSAGE, number=20, message=ExportFormat, + ) + + training_pipeline = proto.Field(proto.STRING, number=7) + + container_spec = proto.Field(proto.MESSAGE, number=9, message="ModelContainerSpec",) + + artifact_uri = proto.Field(proto.STRING, number=26) + + supported_deployment_resources_types = proto.RepeatedField( + proto.ENUM, number=10, enum=DeploymentResourcesType, + ) + + supported_input_storage_formats = proto.RepeatedField(proto.STRING, number=11) + + supported_output_storage_formats = proto.RepeatedField(proto.STRING, number=12) + + create_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) + + deployed_models = proto.RepeatedField( + proto.MESSAGE, number=15, message=deployed_model_ref.DeployedModelRef, + ) + + etag = proto.Field(proto.STRING, number=16) + + labels = proto.MapField(proto.STRING, proto.STRING, number=17) + + encryption_spec = proto.Field( + proto.MESSAGE, number=24, message=gca_encryption_spec.EncryptionSpec, + ) + + +class PredictSchemata(proto.Message): + r"""Contains the schemata used in Model's predictions and explanations + via + ``PredictionService.Predict``, + [PredictionService.Explain][] and + ``BatchPredictionJob``. + + Attributes: + instance_schema_uri (str): + Immutable. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + [ExplainRequest.instances][] and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + parameters_schema_uri (str): + Immutable. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + [ExplainRequest.parameters][] and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported, then it is set to + an empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + prediction_schema_uri (str): + Immutable. 
Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + [ExplainResponse.explanations][], and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + """ + + instance_schema_uri = proto.Field(proto.STRING, number=1) + + parameters_schema_uri = proto.Field(proto.STRING, number=2) + + prediction_schema_uri = proto.Field(proto.STRING, number=3) + + +class ModelContainerSpec(proto.Message): + r"""Specification of a container for serving predictions. This message + is a subset of the Kubernetes Container v1 core + `specification `__. + + Attributes: + image_uri (str): + Required. Immutable. URI of the Docker image to be used as + the custom container for serving predictions. This URI must + identify an image in Artifact Registry or Container + Registry. Learn more about the container publishing + requirements, including permissions requirements for the AI + Platform Service Agent, + `here `__. + + The container image is ingested upon + ``ModelService.UploadModel``, + stored internally, and this original path is afterwards not + used. + + To learn about the requirements for the Docker image itself, + see `Custom container + requirements `__. + command (Sequence[str]): + Immutable. Specifies the command that runs when the + container starts. This overrides the container's + `ENTRYPOINT `__. + Specify this field as an array of executable and arguments, + similar to a Docker ``ENTRYPOINT``'s "exec" form, not its + "shell" form. + + If you do not specify this field, then the container's + ``ENTRYPOINT`` runs, in conjunction with the + ``args`` + field or the container's + ```CMD`` `__, + if either exists. If this field is not specified and the + container does not have an ``ENTRYPOINT``, then refer to the + Docker documentation about how ``CMD`` and ``ENTRYPOINT`` + `interact `__. + + If you specify this field, then you can also specify the + ``args`` field to provide additional arguments for this + command. However, if you specify this field, then the + container's ``CMD`` is ignored. See the `Kubernetes + documentation `__ about how + the ``command`` and ``args`` fields interact with a + container's ``ENTRYPOINT`` and ``CMD``. + + In this field, you can reference environment variables `set + by AI + Platform `__ + and environment variables set in the + ``env`` + field. You cannot reference environment variables set in the + Docker image. In order for environment variables to be + expanded, reference them by using the following syntax: + $(VARIABLE_NAME) Note that this differs from Bash variable + expansion, which does not use parentheses. If a variable + cannot be resolved, the reference in the input string is + used unchanged. To avoid variable expansion, you can escape + this syntax with ``$$``; for example: $$(VARIABLE_NAME) This + field corresponds to the ``command`` field of the Kubernetes + Containers `v1 core + API `__. + args (Sequence[str]): + Immutable. Specifies arguments for the command that runs + when the container starts. This overrides the container's + ```CMD`` `__. 
+ Specify this field as an array of executable and arguments, + similar to a Docker ``CMD``'s "default parameters" form. + + If you don't specify this field but do specify the + ``command`` + field, then the command from the ``command`` field runs + without any additional arguments. See the `Kubernetes + documentation `__ about how + the ``command`` and ``args`` fields interact with a + container's ``ENTRYPOINT`` and ``CMD``. + + If you don't specify this field and don't specify the + ``command`` field, then the container's + ```ENTRYPOINT`` `__ + and ``CMD`` determine what runs based on their default + behavior. See the Docker documentation about how ``CMD`` and + ``ENTRYPOINT`` `interact `__. + + In this field, you can reference environment variables `set + by AI + Platform `__ + and environment variables set in the + ``env`` + field. You cannot reference environment variables set in the + Docker image. In order for environment variables to be + expanded, reference them by using the following syntax: + $(VARIABLE_NAME) Note that this differs from Bash variable + expansion, which does not use parentheses. If a variable + cannot be resolved, the reference in the input string is + used unchanged. To avoid variable expansion, you can escape + this syntax with ``$$``; for example: $$(VARIABLE_NAME) This + field corresponds to the ``args`` field of the Kubernetes + Containers `v1 core + API `__. + env (Sequence[google.cloud.aiplatform_v1.types.EnvVar]): + Immutable. List of environment variables to set in the + container. After the container starts running, code running + in the container can read these environment variables. + + Additionally, the + ``command`` + and + ``args`` + fields can reference these variables. Later entries in this + list can also reference earlier entries. For example, the + following example sets the variable ``VAR_2`` to have the + value ``foo bar``: + + .. code:: json + + [ + { + "name": "VAR_1", + "value": "foo" + }, + { + "name": "VAR_2", + "value": "$(VAR_1) bar" + } + ] + + If you switch the order of the variables in the example, + then the expansion does not occur. + + This field corresponds to the ``env`` field of the + Kubernetes Containers `v1 core + API `__. + ports (Sequence[google.cloud.aiplatform_v1.types.Port]): + Immutable. List of ports to expose from the container. AI + Platform sends any prediction requests that it receives to + the first port on this list. AI Platform also sends + `liveness and health + checks `__ to + this port. + + If you do not specify this field, it defaults to following + value: + + .. code:: json + + [ + { + "containerPort": 8080 + } + ] + + AI Platform does not use ports other than the first one + listed. This field corresponds to the ``ports`` field of the + Kubernetes Containers `v1 core + API `__. + predict_route (str): + Immutable. HTTP path on the container to send prediction + requests to. AI Platform forwards requests sent using + ``projects.locations.endpoints.predict`` + to this path on the container's IP address and port. AI + Platform then returns the container's response in the API + response. + + For example, if you set this field to ``/foo``, then when AI + Platform receives a prediction request, it forwards the + request body in a POST request to the ``/foo`` path on the + port of your container specified by the first value of this + ``ModelContainerSpec``'s + ``ports`` + field. 
+
+            If you don't specify this field, it defaults to the
+            following value when you [deploy this Model to an
+            Endpoint][google.cloud.aiplatform.v1.EndpointService.DeployModel]:
+            /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict
+            The placeholders in this value are replaced as follows:
+
+            -  ENDPOINT: The last segment (following ``endpoints/``) of
+               the ``Endpoint.name`` field of the Endpoint where this
+               Model has been deployed. (AI Platform makes this value
+               available to your container code as the
+               ``AIP_ENDPOINT_ID``
+               environment variable.)
+
+            -  DEPLOYED_MODEL:
+               ``DeployedModel.id``
+               of the ``DeployedModel``. (AI Platform makes this value
+               available to your container code as the
+               ``AIP_DEPLOYED_MODEL_ID`` environment
+               variable.)
+        health_route (str):
+            Immutable. HTTP path on the container to send health checks
+            to. AI Platform intermittently sends GET requests to this
+            path on the container's IP address and port to check that
+            the container is healthy. Read more about health checks.
+
+            For example, if you set this field to ``/bar``, then AI
+            Platform intermittently sends a GET request to the ``/bar``
+            path on the port of your container specified by the first
+            value of this ``ModelContainerSpec``'s
+            ``ports``
+            field.
+
+            If you don't specify this field, it defaults to the
+            following value when you [deploy this Model to an
+            Endpoint][google.cloud.aiplatform.v1.EndpointService.DeployModel]:
+            /v1/endpoints/ENDPOINT/deployedModels/DEPLOYED_MODEL:predict
+            The placeholders in this value are replaced as follows:
+
+            -  ENDPOINT: The last segment (following ``endpoints/``) of
+               the ``Endpoint.name`` field of the Endpoint where this
+               Model has been deployed. (AI Platform makes this value
+               available to your container code as the
+               ``AIP_ENDPOINT_ID``
+               environment variable.)
+
+            -  DEPLOYED_MODEL:
+               ``DeployedModel.id``
+               of the ``DeployedModel``. (AI Platform makes this value
+               available to your container code as the
+               ``AIP_DEPLOYED_MODEL_ID``
+               environment variable.)
+    """
+
+    image_uri = proto.Field(proto.STRING, number=1)
+
+    command = proto.RepeatedField(proto.STRING, number=2)
+
+    args = proto.RepeatedField(proto.STRING, number=3)
+
+    env = proto.RepeatedField(proto.MESSAGE, number=4, message=env_var.EnvVar,)
+
+    ports = proto.RepeatedField(proto.MESSAGE, number=5, message="Port",)
+
+    predict_route = proto.Field(proto.STRING, number=6)
+
+    health_route = proto.Field(proto.STRING, number=7)
+
+
+class Port(proto.Message):
+    r"""Represents a network port in a container.
+
+    Attributes:
+        container_port (int):
+            The number of the port to expose on the pod's
+            IP address. Must be a valid port number, between
+            1 and 65535 inclusive.
+    """
+
+    container_port = proto.Field(proto.INT32, number=3)
+
+
+__all__ = tuple(sorted(__protobuf__.manifest))
diff --git a/google/cloud/aiplatform_v1/types/model_evaluation.py b/google/cloud/aiplatform_v1/types/model_evaluation.py
new file mode 100644
index 0000000000..f617f3d197
--- /dev/null
+++ b/google/cloud/aiplatform_v1/types/model_evaluation.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"ModelEvaluation",}, +) + + +class ModelEvaluation(proto.Message): + r"""A collection of metrics calculated by comparing Model's + predictions on all of the test data against annotations from the + test data. + + Attributes: + name (str): + Output only. The resource name of the + ModelEvaluation. + metrics_schema_uri (str): + Output only. Points to a YAML file stored on Google Cloud + Storage describing the + ``metrics`` + of this ModelEvaluation. The schema is defined as an OpenAPI + 3.0.2 `Schema + Object `__. + metrics (google.protobuf.struct_pb2.Value): + Output only. Evaluation metrics of the Model. The schema of + the metrics is stored in + ``metrics_schema_uri`` + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + ModelEvaluation was created. + slice_dimensions (Sequence[str]): + Output only. All possible + ``dimensions`` of + ModelEvaluationSlices. The dimensions can be used as the + filter of the + ``ModelService.ListModelEvaluationSlices`` + request, in the form of ``slice.dimension = ``. + """ + + name = proto.Field(proto.STRING, number=1) + + metrics_schema_uri = proto.Field(proto.STRING, number=2) + + metrics = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + + create_time = proto.Field(proto.MESSAGE, number=4, message=timestamp.Timestamp,) + + slice_dimensions = proto.RepeatedField(proto.STRING, number=5) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/model_evaluation_slice.py b/google/cloud/aiplatform_v1/types/model_evaluation_slice.py new file mode 100644 index 0000000000..5653c3d2b6 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/model_evaluation_slice.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"ModelEvaluationSlice",}, +) + + +class ModelEvaluationSlice(proto.Message): + r"""A collection of metrics calculated by comparing Model's + predictions on a slice of the test data against ground truth + annotations. + + Attributes: + name (str): + Output only. 
The resource name of the + ModelEvaluationSlice. + slice_ (google.cloud.aiplatform_v1.types.ModelEvaluationSlice.Slice): + Output only. The slice of the test data that + is used to evaluate the Model. + metrics_schema_uri (str): + Output only. Points to a YAML file stored on Google Cloud + Storage describing the + ``metrics`` + of this ModelEvaluationSlice. The schema is defined as an + OpenAPI 3.0.2 `Schema + Object `__. + metrics (google.protobuf.struct_pb2.Value): + Output only. Sliced evaluation metrics of the Model. The + schema of the metrics is stored in + ``metrics_schema_uri`` + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Timestamp when this + ModelEvaluationSlice was created. + """ + + class Slice(proto.Message): + r"""Definition of a slice. + + Attributes: + dimension (str): + Output only. The dimension of the slice. Well-known + dimensions are: + + - ``annotationSpec``: This slice is on the test data that + has either ground truth or prediction with + ``AnnotationSpec.display_name`` + equals to + ``value``. + value (str): + Output only. The value of the dimension in + this slice. + """ + + dimension = proto.Field(proto.STRING, number=1) + + value = proto.Field(proto.STRING, number=2) + + name = proto.Field(proto.STRING, number=1) + + slice_ = proto.Field(proto.MESSAGE, number=2, message=Slice,) + + metrics_schema_uri = proto.Field(proto.STRING, number=3) + + metrics = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + + create_time = proto.Field(proto.MESSAGE, number=5, message=timestamp.Timestamp,) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/model_service.py b/google/cloud/aiplatform_v1/types/model_service.py new file mode 100644 index 0000000000..454e014fd5 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/model_service.py @@ -0,0 +1,487 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+import proto  # type: ignore
+
+
+from google.cloud.aiplatform_v1.types import io
+from google.cloud.aiplatform_v1.types import model as gca_model
+from google.cloud.aiplatform_v1.types import model_evaluation
+from google.cloud.aiplatform_v1.types import model_evaluation_slice
+from google.cloud.aiplatform_v1.types import operation
+from google.protobuf import field_mask_pb2 as field_mask  # type: ignore
+
+
+__protobuf__ = proto.module(
+    package="google.cloud.aiplatform.v1",
+    manifest={
+        "UploadModelRequest",
+        "UploadModelOperationMetadata",
+        "UploadModelResponse",
+        "GetModelRequest",
+        "ListModelsRequest",
+        "ListModelsResponse",
+        "UpdateModelRequest",
+        "DeleteModelRequest",
+        "ExportModelRequest",
+        "ExportModelOperationMetadata",
+        "ExportModelResponse",
+        "GetModelEvaluationRequest",
+        "ListModelEvaluationsRequest",
+        "ListModelEvaluationsResponse",
+        "GetModelEvaluationSliceRequest",
+        "ListModelEvaluationSlicesRequest",
+        "ListModelEvaluationSlicesResponse",
+    },
+)
+
+
+class UploadModelRequest(proto.Message):
+    r"""Request message for
+    ``ModelService.UploadModel``.
+
+    Attributes:
+        parent (str):
+            Required. The resource name of the Location into which to
+            upload the Model. Format:
+            ``projects/{project}/locations/{location}``
+        model (google.cloud.aiplatform_v1.types.Model):
+            Required. The Model to create.
+    """
+
+    parent = proto.Field(proto.STRING, number=1)
+
+    model = proto.Field(proto.MESSAGE, number=2, message=gca_model.Model,)
+
+
+class UploadModelOperationMetadata(proto.Message):
+    r"""Details of
+    ``ModelService.UploadModel``
+    operation.
+
+    Attributes:
+        generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata):
+            The common part of the operation metadata.
+    """
+
+    generic_metadata = proto.Field(
+        proto.MESSAGE, number=1, message=operation.GenericOperationMetadata,
+    )
+
+
+class UploadModelResponse(proto.Message):
+    r"""Response message of
+    ``ModelService.UploadModel``
+    operation.
+
+    Attributes:
+        model (str):
+            The name of the uploaded Model resource. Format:
+            ``projects/{project}/locations/{location}/models/{model}``
+    """
+
+    model = proto.Field(proto.STRING, number=1)
+
+
+class GetModelRequest(proto.Message):
+    r"""Request message for
+    ``ModelService.GetModel``.
+
+    Attributes:
+        name (str):
+            Required. The name of the Model resource. Format:
+            ``projects/{project}/locations/{location}/models/{model}``
+    """
+
+    name = proto.Field(proto.STRING, number=1)
+
+
+class ListModelsRequest(proto.Message):
+    r"""Request message for
+    ``ModelService.ListModels``.
+
+    Attributes:
+        parent (str):
+            Required. The resource name of the Location to list the
+            Models from. Format:
+            ``projects/{project}/locations/{location}``
+        filter (str):
+            An expression for filtering the results of the request. For
+            field names both snake_case and camelCase are supported.
+
+            -  ``model`` supports = and !=. ``model`` represents the
+               Model ID, i.e. the last segment of the Model's [resource
+               name][google.cloud.aiplatform.v1.Model.name].
+            -  ``display_name`` supports = and !=
+            -  ``labels`` supports general map functions, that is:
+
+               -  ``labels.key=value`` - key:value equality
+               -  ``labels.key:*`` or ``labels:key`` - key existence
+               -  A key including a space must be quoted.
+                  ``labels."a key"``.
+
+            Some examples:
+
+            -  ``model=1234``
+            -  ``displayName="myDisplayName"``
+            -  ``labels.myKey="myValue"``
+        page_size (int):
+            The standard list page size.
+        page_token (str):
+            The standard list page token.
Typically obtained via + ``ListModelsResponse.next_page_token`` + of the previous + ``ModelService.ListModels`` + call. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + order_by (str): + A comma-separated list of fields to order by, sorted in + ascending order. Use "desc" after a field name for + descending. Supported fields: + + - ``display_name`` + - ``create_time`` + - ``update_time`` + + Example: ``display_name, create_time desc``. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + order_by = proto.Field(proto.STRING, number=6) + + +class ListModelsResponse(proto.Message): + r"""Response message for + ``ModelService.ListModels`` + + Attributes: + models (Sequence[google.cloud.aiplatform_v1.types.Model]): + List of Models in the requested page. + next_page_token (str): + A token to retrieve the next page of results. Pass to + ``ListModelsRequest.page_token`` + to obtain that page. + """ + + @property + def raw_page(self): + return self + + models = proto.RepeatedField(proto.MESSAGE, number=1, message=gca_model.Model,) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class UpdateModelRequest(proto.Message): + r"""Request message for + ``ModelService.UpdateModel``. + + Attributes: + model (google.cloud.aiplatform_v1.types.Model): + Required. The Model which replaces the + resource on the server. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the resource. For the + ``FieldMask`` definition, see + `FieldMask `__. + """ + + model = proto.Field(proto.MESSAGE, number=1, message=gca_model.Model,) + + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + + +class DeleteModelRequest(proto.Message): + r"""Request message for + ``ModelService.DeleteModel``. + + Attributes: + name (str): + Required. The name of the Model resource to be deleted. + Format: + ``projects/{project}/locations/{location}/models/{model}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ExportModelRequest(proto.Message): + r"""Request message for + ``ModelService.ExportModel``. + + Attributes: + name (str): + Required. The resource name of the Model to export. Format: + ``projects/{project}/locations/{location}/models/{model}`` + output_config (google.cloud.aiplatform_v1.types.ExportModelRequest.OutputConfig): + Required. The desired output location and + configuration. + """ + + class OutputConfig(proto.Message): + r"""Output configuration for the Model export. + + Attributes: + export_format_id (str): + The ID of the format in which the Model must be exported. + Each Model lists the [export formats it + supports][google.cloud.aiplatform.v1.Model.supported_export_formats]. + If no value is provided here, then the first from the list + of the Model's supported formats is used by default. + artifact_destination (google.cloud.aiplatform_v1.types.GcsDestination): + The Cloud Storage location to which the Model artifact is to + be written. Under the directory given as the destination a + new one with name + "``model-export-<model-display-name>-<timestamp-of-export-call>``", + where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 + format, will be created. Inside, the Model and any of its + supporting files will be written.
This field should only be + set when the ``exportableContent`` field of the + [Model.supported_export_formats] object contains + ``ARTIFACT``. + image_destination (google.cloud.aiplatform_v1.types.ContainerRegistryDestination): + The Google Container Registry or Artifact Registry uri where + the Model container image will be copied to. This field + should only be set when the ``exportableContent`` field of + the [Model.supported_export_formats] object contains + ``IMAGE``. + """ + + export_format_id = proto.Field(proto.STRING, number=1) + + artifact_destination = proto.Field( + proto.MESSAGE, number=3, message=io.GcsDestination, + ) + + image_destination = proto.Field( + proto.MESSAGE, number=4, message=io.ContainerRegistryDestination, + ) + + name = proto.Field(proto.STRING, number=1) + + output_config = proto.Field(proto.MESSAGE, number=2, message=OutputConfig,) + + +class ExportModelOperationMetadata(proto.Message): + r"""Details of + ``ModelService.ExportModel`` + operation. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata): + The common part of the operation metadata. + output_info (google.cloud.aiplatform_v1.types.ExportModelOperationMetadata.OutputInfo): + Output only. Information further describing + the output of this Model export. + """ + + class OutputInfo(proto.Message): + r"""Further describes the output of the ExportModel. Supplements + ``ExportModelRequest.OutputConfig``. + + Attributes: + artifact_output_uri (str): + Output only. If the Model artifact is being + exported to Google Cloud Storage this is the + full path of the directory created, into which + the Model files are being written to. + image_output_uri (str): + Output only. If the Model image is being + exported to Google Container Registry or + Artifact Registry this is the full path of the + image created. + """ + + artifact_output_uri = proto.Field(proto.STRING, number=2) + + image_output_uri = proto.Field(proto.STRING, number=3) + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + output_info = proto.Field(proto.MESSAGE, number=2, message=OutputInfo,) + + +class ExportModelResponse(proto.Message): + r"""Response message of + ``ModelService.ExportModel`` + operation. + """ + + +class GetModelEvaluationRequest(proto.Message): + r"""Request message for + ``ModelService.GetModelEvaluation``. + + Attributes: + name (str): + Required. The name of the ModelEvaluation resource. Format: + + ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListModelEvaluationsRequest(proto.Message): + r"""Request message for + ``ModelService.ListModelEvaluations``. + + Attributes: + parent (str): + Required. The resource name of the Model to list the + ModelEvaluations from. Format: + ``projects/{project}/locations/{location}/models/{model}`` + filter (str): + The standard list filter. + page_size (int): + The standard list page size. + page_token (str): + The standard list page token. Typically obtained via + ``ListModelEvaluationsResponse.next_page_token`` + of the previous + ``ModelService.ListModelEvaluations`` + call. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. 
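        Example (an illustrative sketch, not generated output; the client
        variable, project, location, and model ID below are placeholder
        assumptions)::

            from google.cloud import aiplatform_v1

            client = aiplatform_v1.ModelServiceClient()
            request = aiplatform_v1.ListModelEvaluationsRequest(
                parent="projects/my-project/locations/us-central1/models/123",
                page_size=100,
            )
            # The returned pager follows next_page_token transparently.
            for evaluation in client.list_model_evaluations(request=request):
                print(evaluation.name)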
+ """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + +class ListModelEvaluationsResponse(proto.Message): + r"""Response message for + ``ModelService.ListModelEvaluations``. + + Attributes: + model_evaluations (Sequence[google.cloud.aiplatform_v1.types.ModelEvaluation]): + List of ModelEvaluations in the requested + page. + next_page_token (str): + A token to retrieve next page of results. Pass to + ``ListModelEvaluationsRequest.page_token`` + to obtain that page. + """ + + @property + def raw_page(self): + return self + + model_evaluations = proto.RepeatedField( + proto.MESSAGE, number=1, message=model_evaluation.ModelEvaluation, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class GetModelEvaluationSliceRequest(proto.Message): + r"""Request message for + ``ModelService.GetModelEvaluationSlice``. + + Attributes: + name (str): + Required. The name of the ModelEvaluationSlice resource. + Format: + + ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListModelEvaluationSlicesRequest(proto.Message): + r"""Request message for + ``ModelService.ListModelEvaluationSlices``. + + Attributes: + parent (str): + Required. The resource name of the ModelEvaluation to list + the ModelEvaluationSlices from. Format: + + ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` + filter (str): + The standard list filter. + + - ``slice.dimension`` - for =. + page_size (int): + The standard list page size. + page_token (str): + The standard list page token. Typically obtained via + ``ListModelEvaluationSlicesResponse.next_page_token`` + of the previous + ``ModelService.ListModelEvaluationSlices`` + call. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + +class ListModelEvaluationSlicesResponse(proto.Message): + r"""Response message for + ``ModelService.ListModelEvaluationSlices``. + + Attributes: + model_evaluation_slices (Sequence[google.cloud.aiplatform_v1.types.ModelEvaluationSlice]): + List of ModelEvaluations in the requested + page. + next_page_token (str): + A token to retrieve next page of results. Pass to + ``ListModelEvaluationSlicesRequest.page_token`` + to obtain that page. 
+ """ + + @property + def raw_page(self): + return self + + model_evaluation_slices = proto.RepeatedField( + proto.MESSAGE, number=1, message=model_evaluation_slice.ModelEvaluationSlice, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/operation.py b/google/cloud/aiplatform_v1/types/operation.py new file mode 100644 index 0000000000..fe24030e79 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/operation.py @@ -0,0 +1,73 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={"GenericOperationMetadata", "DeleteOperationMetadata",}, +) + + +class GenericOperationMetadata(proto.Message): + r"""Generic Metadata shared by all operations. + + Attributes: + partial_failures (Sequence[google.rpc.status_pb2.Status]): + Output only. Partial failures encountered. + E.g. single files that couldn't be read. + This field should never exceed 20 entries. + Status details field will contain standard GCP + error details. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the operation was + created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the operation was + updated for the last time. If the operation has + finished (successfully or not), this is the + finish time. + """ + + partial_failures = proto.RepeatedField( + proto.MESSAGE, number=1, message=status.Status, + ) + + create_time = proto.Field(proto.MESSAGE, number=2, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=3, message=timestamp.Timestamp,) + + +class DeleteOperationMetadata(proto.Message): + r"""Details of operations that perform deletes of any entities. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata): + The common part of the operation metadata. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message="GenericOperationMetadata", + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/pipeline_service.py b/google/cloud/aiplatform_v1/types/pipeline_service.py new file mode 100644 index 0000000000..b2c6d5bbe3 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/pipeline_service.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import training_pipeline as gca_training_pipeline +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "CreateTrainingPipelineRequest", + "GetTrainingPipelineRequest", + "ListTrainingPipelinesRequest", + "ListTrainingPipelinesResponse", + "DeleteTrainingPipelineRequest", + "CancelTrainingPipelineRequest", + }, +) + + +class CreateTrainingPipelineRequest(proto.Message): + r"""Request message for + ``PipelineService.CreateTrainingPipeline``. + + Attributes: + parent (str): + Required. The resource name of the Location to create the + TrainingPipeline in. Format: + ``projects/{project}/locations/{location}`` + training_pipeline (google.cloud.aiplatform_v1.types.TrainingPipeline): + Required. The TrainingPipeline to create. + """ + + parent = proto.Field(proto.STRING, number=1) + + training_pipeline = proto.Field( + proto.MESSAGE, number=2, message=gca_training_pipeline.TrainingPipeline, + ) + + +class GetTrainingPipelineRequest(proto.Message): + r"""Request message for + ``PipelineService.GetTrainingPipeline``. + + Attributes: + name (str): + Required. The name of the TrainingPipeline resource. Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListTrainingPipelinesRequest(proto.Message): + r"""Request message for + ``PipelineService.ListTrainingPipelines``. + + Attributes: + parent (str): + Required. The resource name of the Location to list the + TrainingPipelines from. Format: + ``projects/{project}/locations/{location}`` + filter (str): + The standard list filter. Supported fields: + + - ``display_name`` supports = and !=. + + - ``state`` supports = and !=. + + Some examples of using the filter are: + + - ``state="PIPELINE_STATE_SUCCEEDED" AND display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_RUNNING" OR display_name="my_pipeline"`` + + - ``NOT display_name="my_pipeline"`` + + - ``state="PIPELINE_STATE_FAILED"`` + page_size (int): + The standard list page size. + page_token (str): + The standard list page token. Typically obtained via + ``ListTrainingPipelinesResponse.next_page_token`` + of the previous + ``PipelineService.ListTrainingPipelines`` + call. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + """ + + parent = proto.Field(proto.STRING, number=1) + + filter = proto.Field(proto.STRING, number=2) + + page_size = proto.Field(proto.INT32, number=3) + + page_token = proto.Field(proto.STRING, number=4) + + read_mask = proto.Field(proto.MESSAGE, number=5, message=field_mask.FieldMask,) + + +class ListTrainingPipelinesResponse(proto.Message): + r"""Response message for + ``PipelineService.ListTrainingPipelines`` + + Attributes: + training_pipelines (Sequence[google.cloud.aiplatform_v1.types.TrainingPipeline]): + List of TrainingPipelines in the requested + page. + next_page_token (str): + A token to retrieve the next page of results. 
Pass to + ``ListTrainingPipelinesRequest.page_token`` + to obtain that page. + """ + + @property + def raw_page(self): + return self + + training_pipelines = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_training_pipeline.TrainingPipeline, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class DeleteTrainingPipelineRequest(proto.Message): + r"""Request message for + ``PipelineService.DeleteTrainingPipeline``. + + Attributes: + name (str): + Required. The name of the TrainingPipeline resource to be + deleted. Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +class CancelTrainingPipelineRequest(proto.Message): + r"""Request message for + ``PipelineService.CancelTrainingPipeline``. + + Attributes: + name (str): + Required. The name of the TrainingPipeline to cancel. + Format: + + ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + """ + + name = proto.Field(proto.STRING, number=1) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/pipeline_state.py b/google/cloud/aiplatform_v1/types/pipeline_state.py new file mode 100644 index 0000000000..f6a885ae42 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/pipeline_state.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"PipelineState",}, +) + + +class PipelineState(proto.Enum): + r"""Describes the state of a pipeline.""" + PIPELINE_STATE_UNSPECIFIED = 0 + PIPELINE_STATE_QUEUED = 1 + PIPELINE_STATE_PENDING = 2 + PIPELINE_STATE_RUNNING = 3 + PIPELINE_STATE_SUCCEEDED = 4 + PIPELINE_STATE_FAILED = 5 + PIPELINE_STATE_CANCELLING = 6 + PIPELINE_STATE_CANCELLED = 7 + PIPELINE_STATE_PAUSED = 8 + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/prediction_service.py b/google/cloud/aiplatform_v1/types/prediction_service.py new file mode 100644 index 0000000000..21a01372f4 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/prediction_service.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import proto # type: ignore + + +from google.protobuf import struct_pb2 as struct # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={"PredictRequest", "PredictResponse",}, +) + + +class PredictRequest(proto.Message): + r"""Request message for + ``PredictionService.Predict``. + + Attributes: + endpoint (str): + Required. The name of the Endpoint requested to serve the + prediction. Format: + ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + instances (Sequence[google.protobuf.struct_pb2.Value]): + Required. The instances that are the input to the prediction + call. A DeployedModel may have an upper limit on the number + of instances it supports per request, and when it is + exceeded the prediction call errors in case of AutoML + Models, or, in case of customer created Models, the + behaviour is as documented by that Model. The schema of any + single instance may be specified via Endpoint's + DeployedModels' + [Model's][google.cloud.aiplatform.v1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] + ``instance_schema_uri``. + parameters (google.protobuf.struct_pb2.Value): + The parameters that govern the prediction. The schema of the + parameters may be specified via Endpoint's DeployedModels' + [Model's ][google.cloud.aiplatform.v1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] + ``parameters_schema_uri``. + """ + + endpoint = proto.Field(proto.STRING, number=1) + + instances = proto.RepeatedField(proto.MESSAGE, number=2, message=struct.Value,) + + parameters = proto.Field(proto.MESSAGE, number=3, message=struct.Value,) + + +class PredictResponse(proto.Message): + r"""Response message for + ``PredictionService.Predict``. + + Attributes: + predictions (Sequence[google.protobuf.struct_pb2.Value]): + The predictions that are the output of the predictions call. + The schema of any single prediction may be specified via + Endpoint's DeployedModels' [Model's + ][google.cloud.aiplatform.v1.DeployedModel.model] + [PredictSchemata's][google.cloud.aiplatform.v1.Model.predict_schemata] + ``prediction_schema_uri``. + deployed_model_id (str): + ID of the Endpoint's DeployedModel that + served this prediction. + """ + + predictions = proto.RepeatedField(proto.MESSAGE, number=1, message=struct.Value,) + + deployed_model_id = proto.Field(proto.STRING, number=2) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/specialist_pool.py b/google/cloud/aiplatform_v1/types/specialist_pool.py new file mode 100644 index 0000000000..6265316bd5 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/specialist_pool.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"SpecialistPool",}, +) + + +class SpecialistPool(proto.Message): + r"""SpecialistPool represents customers' own workforce to work on + their data labeling jobs. It includes a group of specialist + managers who are responsible for managing the labelers in this + pool as well as customers' data labeling jobs associated with + this pool. + Customers create a specialist pool as well as start data labeling + jobs on Cloud; managers and labelers work with the jobs using the + CrowdCompute console. + + Attributes: + name (str): + Required. The resource name of the + SpecialistPool. + display_name (str): + Required. The user-defined name of the + SpecialistPool. The name can be up to 128 + characters long and can consist of any UTF-8 + characters. + This field should be unique at the project level. + specialist_managers_count (int): + Output only. The number of Specialists in + this SpecialistPool. + specialist_manager_emails (Sequence[str]): + The email addresses of the specialists in the + SpecialistPool. + pending_data_labeling_jobs (Sequence[str]): + Output only. The resource names of the pending + data labeling jobs. + """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + specialist_managers_count = proto.Field(proto.INT32, number=3) + + specialist_manager_emails = proto.RepeatedField(proto.STRING, number=4) + + pending_data_labeling_jobs = proto.RepeatedField(proto.STRING, number=5) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/specialist_pool_service.py b/google/cloud/aiplatform_v1/types/specialist_pool_service.py new file mode 100644 index 0000000000..69e49bb355 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/specialist_pool_service.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import operation +from google.cloud.aiplatform_v1.types import specialist_pool as gca_specialist_pool +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "CreateSpecialistPoolRequest", + "CreateSpecialistPoolOperationMetadata", + "GetSpecialistPoolRequest", + "ListSpecialistPoolsRequest", + "ListSpecialistPoolsResponse", + "DeleteSpecialistPoolRequest", + "UpdateSpecialistPoolRequest", + "UpdateSpecialistPoolOperationMetadata", + }, +) + + +class CreateSpecialistPoolRequest(proto.Message): + r"""Request message for + ``SpecialistPoolService.CreateSpecialistPool``. + + Attributes: + parent (str): + Required. The parent Project name for the new + SpecialistPool. The form is + ``projects/{project}/locations/{location}``. + specialist_pool (google.cloud.aiplatform_v1.types.SpecialistPool): + Required. The SpecialistPool to create.
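        Example (a minimal sketch; assumes a ``SpecialistPoolServiceClient``
        named ``client`` and placeholder project, location, and email
        values)::

            pool = aiplatform_v1.SpecialistPool(
                display_name="my_pool",
                specialist_manager_emails=["manager@example.com"],
            )
            operation = client.create_specialist_pool(
                parent="projects/my-project/locations/us-central1",
                specialist_pool=pool,
            )
            # create_specialist_pool returns a long-running operation;
            # result() blocks until the pool exists.
            pool = operation.result()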
+ """ + + parent = proto.Field(proto.STRING, number=1) + + specialist_pool = proto.Field( + proto.MESSAGE, number=2, message=gca_specialist_pool.SpecialistPool, + ) + + +class CreateSpecialistPoolOperationMetadata(proto.Message): + r"""Runtime operation information for + ``SpecialistPoolService.CreateSpecialistPool``. + + Attributes: + generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata): + The operation generic information. + """ + + generic_metadata = proto.Field( + proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, + ) + + +class GetSpecialistPoolRequest(proto.Message): + r"""Request message for + ``SpecialistPoolService.GetSpecialistPool``. + + Attributes: + name (str): + Required. The name of the SpecialistPool resource. The form + is + + ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}``. + """ + + name = proto.Field(proto.STRING, number=1) + + +class ListSpecialistPoolsRequest(proto.Message): + r"""Request message for + ``SpecialistPoolService.ListSpecialistPools``. + + Attributes: + parent (str): + Required. The name of the SpecialistPool's parent resource. + Format: ``projects/{project}/locations/{location}`` + page_size (int): + The standard list page size. + page_token (str): + The standard list page token. Typically obtained by + ``ListSpecialistPoolsResponse.next_page_token`` + of the previous + ``SpecialistPoolService.ListSpecialistPools`` + call. Return first page if empty. + read_mask (google.protobuf.field_mask_pb2.FieldMask): + Mask specifying which fields to read. + FieldMask represents a set of + """ + + parent = proto.Field(proto.STRING, number=1) + + page_size = proto.Field(proto.INT32, number=2) + + page_token = proto.Field(proto.STRING, number=3) + + read_mask = proto.Field(proto.MESSAGE, number=4, message=field_mask.FieldMask,) + + +class ListSpecialistPoolsResponse(proto.Message): + r"""Response message for + ``SpecialistPoolService.ListSpecialistPools``. + + Attributes: + specialist_pools (Sequence[google.cloud.aiplatform_v1.types.SpecialistPool]): + A list of SpecialistPools that matches the + specified filter in the request. + next_page_token (str): + The standard List next-page token. + """ + + @property + def raw_page(self): + return self + + specialist_pools = proto.RepeatedField( + proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, + ) + + next_page_token = proto.Field(proto.STRING, number=2) + + +class DeleteSpecialistPoolRequest(proto.Message): + r"""Request message for + ``SpecialistPoolService.DeleteSpecialistPool``. + + Attributes: + name (str): + Required. The resource name of the SpecialistPool to delete. + Format: + ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}`` + force (bool): + If set to true, any specialist managers in + this SpecialistPool will also be deleted. + (Otherwise, the request will only work if the + SpecialistPool has no specialist managers.) + """ + + name = proto.Field(proto.STRING, number=1) + + force = proto.Field(proto.BOOL, number=2) + + +class UpdateSpecialistPoolRequest(proto.Message): + r"""Request message for + ``SpecialistPoolService.UpdateSpecialistPool``. + + Attributes: + specialist_pool (google.cloud.aiplatform_v1.types.SpecialistPool): + Required. The SpecialistPool which replaces + the resource on the server. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the + resource. 
+ """ + + specialist_pool = proto.Field( + proto.MESSAGE, number=1, message=gca_specialist_pool.SpecialistPool, + ) + + update_mask = proto.Field(proto.MESSAGE, number=2, message=field_mask.FieldMask,) + + +class UpdateSpecialistPoolOperationMetadata(proto.Message): + r"""Runtime operation metadata for + ``SpecialistPoolService.UpdateSpecialistPool``. + + Attributes: + specialist_pool (str): + Output only. The name of the SpecialistPool to which the + specialists are being added. Format: + + ``projects/{project_id}/locations/{location_id}/specialistPools/{specialist_pool}`` + generic_metadata (google.cloud.aiplatform_v1.types.GenericOperationMetadata): + The operation generic information. + """ + + specialist_pool = proto.Field(proto.STRING, number=1) + + generic_metadata = proto.Field( + proto.MESSAGE, number=2, message=operation.GenericOperationMetadata, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/study.py b/google/cloud/aiplatform_v1/types/study.py new file mode 100644 index 0000000000..99a688f045 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/study.py @@ -0,0 +1,444 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={"Trial", "StudySpec", "Measurement",}, +) + + +class Trial(proto.Message): + r"""A message representing a Trial. A Trial contains a unique set + of Parameters that has been or will be evaluated, along with the + objective metrics got by running the Trial. + + Attributes: + id (str): + Output only. The identifier of the Trial + assigned by the service. + state (google.cloud.aiplatform_v1.types.Trial.State): + Output only. The detailed state of the Trial. + parameters (Sequence[google.cloud.aiplatform_v1.types.Trial.Parameter]): + Output only. The parameters of the Trial. + final_measurement (google.cloud.aiplatform_v1.types.Measurement): + Output only. The final measurement containing + the objective value. + start_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the Trial was started. + end_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the Trial's status changed to + ``SUCCEEDED`` or ``INFEASIBLE``. + custom_job (str): + Output only. The CustomJob name linked to the + Trial. It's set for a HyperparameterTuningJob's + Trial. + """ + + class State(proto.Enum): + r"""Describes a Trial state.""" + STATE_UNSPECIFIED = 0 + REQUESTED = 1 + ACTIVE = 2 + STOPPING = 3 + SUCCEEDED = 4 + INFEASIBLE = 5 + + class Parameter(proto.Message): + r"""A message representing a parameter to be tuned. + + Attributes: + parameter_id (str): + Output only. The ID of the parameter. 
The parameter should + be defined in [StudySpec's + Parameters][google.cloud.aiplatform.v1.StudySpec.parameters]. + value (google.protobuf.struct_pb2.Value): + Output only. The value of the parameter. ``number_value`` + will be set if a parameter defined in StudySpec is of type + 'INTEGER', 'DOUBLE' or 'DISCRETE'. ``string_value`` will be + set if a parameter defined in StudySpec is of type + 'CATEGORICAL'. + """ + + parameter_id = proto.Field(proto.STRING, number=1) + + value = proto.Field(proto.MESSAGE, number=2, message=struct.Value,) + + id = proto.Field(proto.STRING, number=2) + + state = proto.Field(proto.ENUM, number=3, enum=State,) + + parameters = proto.RepeatedField(proto.MESSAGE, number=4, message=Parameter,) + + final_measurement = proto.Field(proto.MESSAGE, number=5, message="Measurement",) + + start_time = proto.Field(proto.MESSAGE, number=7, message=timestamp.Timestamp,) + + end_time = proto.Field(proto.MESSAGE, number=8, message=timestamp.Timestamp,) + + custom_job = proto.Field(proto.STRING, number=11) + + +class StudySpec(proto.Message): + r"""Represents specification of a Study. + + Attributes: + metrics (Sequence[google.cloud.aiplatform_v1.types.StudySpec.MetricSpec]): + Required. Metric specs for the Study. + parameters (Sequence[google.cloud.aiplatform_v1.types.StudySpec.ParameterSpec]): + Required. The set of parameters to tune. + algorithm (google.cloud.aiplatform_v1.types.StudySpec.Algorithm): + The search algorithm specified for the Study. + observation_noise (google.cloud.aiplatform_v1.types.StudySpec.ObservationNoise): + The observation noise level of the study. + Currently only supported by the Vizier service. + Not supported by HyperparameterTuningJob or + TrainingPipeline. + measurement_selection_type (google.cloud.aiplatform_v1.types.StudySpec.MeasurementSelectionType): + Describes which measurement selection type + will be used. + """ + + class Algorithm(proto.Enum): + r"""The available search algorithms for the Study.""" + ALGORITHM_UNSPECIFIED = 0 + GRID_SEARCH = 2 + RANDOM_SEARCH = 3 + + class ObservationNoise(proto.Enum): + r"""Describes the noise level of the repeated observations. + "Noisy" means that the repeated observations with the same Trial + parameters may lead to different metric evaluations. + """ + OBSERVATION_NOISE_UNSPECIFIED = 0 + LOW = 1 + HIGH = 2 + + class MeasurementSelectionType(proto.Enum): + r"""This indicates which measurement to use if/when the service + automatically selects the final measurement from previously reported + intermediate measurements. Choose this based on two considerations: + A) Do you expect your measurements to monotonically improve? If so, + choose LAST_MEASUREMENT. On the other hand, if you're in a situation + where your system can "over-train" and you expect the performance to + get better for a while but then start declining, choose + BEST_MEASUREMENT. B) Are your measurements significantly noisy + and/or irreproducible? If so, BEST_MEASUREMENT will tend to be + over-optimistic, and it may be better to choose LAST_MEASUREMENT. If + both or neither of (A) and (B) apply, it doesn't matter which + selection type is chosen. + """ + MEASUREMENT_SELECTION_TYPE_UNSPECIFIED = 0 + LAST_MEASUREMENT = 1 + BEST_MEASUREMENT = 2 + + class MetricSpec(proto.Message): + r"""Represents a metric to optimize. + + Attributes: + metric_id (str): + Required. The ID of the metric. Must not + contain whitespaces and must be unique amongst + all MetricSpecs.
+ goal (google.cloud.aiplatform_v1.types.StudySpec.MetricSpec.GoalType): + Required. The optimization goal of the + metric. + """ + + class GoalType(proto.Enum): + r"""The available types of optimization goals.""" + GOAL_TYPE_UNSPECIFIED = 0 + MAXIMIZE = 1 + MINIMIZE = 2 + + metric_id = proto.Field(proto.STRING, number=1) + + goal = proto.Field(proto.ENUM, number=2, enum="StudySpec.MetricSpec.GoalType",) + + class ParameterSpec(proto.Message): + r"""Represents a single parameter to optimize. + + Attributes: + double_value_spec (google.cloud.aiplatform_v1.types.StudySpec.ParameterSpec.DoubleValueSpec): + The value spec for a 'DOUBLE' parameter. + integer_value_spec (google.cloud.aiplatform_v1.types.StudySpec.ParameterSpec.IntegerValueSpec): + The value spec for an 'INTEGER' parameter. + categorical_value_spec (google.cloud.aiplatform_v1.types.StudySpec.ParameterSpec.CategoricalValueSpec): + The value spec for a 'CATEGORICAL' parameter. + discrete_value_spec (google.cloud.aiplatform_v1.types.StudySpec.ParameterSpec.DiscreteValueSpec): + The value spec for a 'DISCRETE' parameter. + parameter_id (str): + Required. The ID of the parameter. Must not + contain whitespaces and must be unique amongst + all ParameterSpecs. + scale_type (google.cloud.aiplatform_v1.types.StudySpec.ParameterSpec.ScaleType): + How the parameter should be scaled. Leave unset for + ``CATEGORICAL`` parameters. + conditional_parameter_specs (Sequence[google.cloud.aiplatform_v1.types.StudySpec.ParameterSpec.ConditionalParameterSpec]): + A conditional parameter node is active if the parameter's + value matches the conditional node's parent_value_condition. + + If two items in conditional_parameter_specs have the same + name, they must have disjoint parent_value_condition. + """ + + class ScaleType(proto.Enum): + r"""The type of scaling that should be applied to this parameter.""" + SCALE_TYPE_UNSPECIFIED = 0 + UNIT_LINEAR_SCALE = 1 + UNIT_LOG_SCALE = 2 + UNIT_REVERSE_LOG_SCALE = 3 + + class DoubleValueSpec(proto.Message): + r"""Value specification for a parameter in ``DOUBLE`` type. + + Attributes: + min_value (float): + Required. Inclusive minimum value of the + parameter. + max_value (float): + Required. Inclusive maximum value of the + parameter. + """ + + min_value = proto.Field(proto.DOUBLE, number=1) + + max_value = proto.Field(proto.DOUBLE, number=2) + + class IntegerValueSpec(proto.Message): + r"""Value specification for a parameter in ``INTEGER`` type. + + Attributes: + min_value (int): + Required. Inclusive minimum value of the + parameter. + max_value (int): + Required. Inclusive maximum value of the + parameter. + """ + + min_value = proto.Field(proto.INT64, number=1) + + max_value = proto.Field(proto.INT64, number=2) + + class CategoricalValueSpec(proto.Message): + r"""Value specification for a parameter in ``CATEGORICAL`` type. + + Attributes: + values (Sequence[str]): + Required. The list of possible categories. + """ + + values = proto.RepeatedField(proto.STRING, number=1) + + class DiscreteValueSpec(proto.Message): + r"""Value specification for a parameter in ``DISCRETE`` type. + + Attributes: + values (Sequence[float]): + Required. A list of possible values. + The list should be in increasing order and at + least 1e-10 apart. For instance, this parameter + might have possible settings of 1.5, 2.5, and + 4.0. This list should not contain more than + 1,000 values. 
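            Example (an illustrative parameter definition; the parameter ID
            and values are placeholder assumptions)::

                learning_rate = aiplatform_v1.StudySpec.ParameterSpec(
                    parameter_id="learning_rate",
                    discrete_value_spec=(
                        aiplatform_v1.StudySpec.ParameterSpec.DiscreteValueSpec(
                            values=[0.001, 0.01, 0.1],
                        )
                    ),
                    # A log scale is a common choice for rates spanning
                    # several orders of magnitude.
                    scale_type=(
                        aiplatform_v1.StudySpec.ParameterSpec.ScaleType.UNIT_LOG_SCALE
                    ),
                )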
+ """ + + values = proto.RepeatedField(proto.DOUBLE, number=1) + + class ConditionalParameterSpec(proto.Message): + r"""Represents a parameter spec with condition from its parent + parameter. + + Attributes: + parent_discrete_values (google.cloud.aiplatform_v1.types.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition): + The spec for matching values from a parent parameter of + ``DISCRETE`` type. + parent_int_values (google.cloud.aiplatform_v1.types.StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition): + The spec for matching values from a parent parameter of + ``INTEGER`` type. + parent_categorical_values (google.cloud.aiplatform_v1.types.StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition): + The spec for matching values from a parent parameter of + ``CATEGORICAL`` type. + parameter_spec (google.cloud.aiplatform_v1.types.StudySpec.ParameterSpec): + Required. The spec for a conditional + parameter. + """ + + class DiscreteValueCondition(proto.Message): + r"""Represents the spec to match discrete values from parent + parameter. + + Attributes: + values (Sequence[float]): + Required. Matches values of the parent parameter of + 'DISCRETE' type. All values must exist in + ``discrete_value_spec`` of parent parameter. + + The Epsilon of the value matching is 1e-10. + """ + + values = proto.RepeatedField(proto.DOUBLE, number=1) + + class IntValueCondition(proto.Message): + r"""Represents the spec to match integer values from parent + parameter. + + Attributes: + values (Sequence[int]): + Required. Matches values of the parent parameter of + 'INTEGER' type. All values must lie in + ``integer_value_spec`` of parent parameter. + """ + + values = proto.RepeatedField(proto.INT64, number=1) + + class CategoricalValueCondition(proto.Message): + r"""Represents the spec to match categorical values from parent + parameter. + + Attributes: + values (Sequence[str]): + Required. Matches values of the parent parameter of + 'CATEGORICAL' type. All values must exist in + ``categorical_value_spec`` of parent parameter. 
+ """ + + values = proto.RepeatedField(proto.STRING, number=1) + + parent_discrete_values = proto.Field( + proto.MESSAGE, + number=2, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition", + ) + + parent_int_values = proto.Field( + proto.MESSAGE, + number=3, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition", + ) + + parent_categorical_values = proto.Field( + proto.MESSAGE, + number=4, + oneof="parent_value_condition", + message="StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition", + ) + + parameter_spec = proto.Field( + proto.MESSAGE, number=1, message="StudySpec.ParameterSpec", + ) + + double_value_spec = proto.Field( + proto.MESSAGE, + number=2, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.DoubleValueSpec", + ) + + integer_value_spec = proto.Field( + proto.MESSAGE, + number=3, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.IntegerValueSpec", + ) + + categorical_value_spec = proto.Field( + proto.MESSAGE, + number=4, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.CategoricalValueSpec", + ) + + discrete_value_spec = proto.Field( + proto.MESSAGE, + number=5, + oneof="parameter_value_spec", + message="StudySpec.ParameterSpec.DiscreteValueSpec", + ) + + parameter_id = proto.Field(proto.STRING, number=1) + + scale_type = proto.Field( + proto.ENUM, number=6, enum="StudySpec.ParameterSpec.ScaleType", + ) + + conditional_parameter_specs = proto.RepeatedField( + proto.MESSAGE, + number=10, + message="StudySpec.ParameterSpec.ConditionalParameterSpec", + ) + + metrics = proto.RepeatedField(proto.MESSAGE, number=1, message=MetricSpec,) + + parameters = proto.RepeatedField(proto.MESSAGE, number=2, message=ParameterSpec,) + + algorithm = proto.Field(proto.ENUM, number=3, enum=Algorithm,) + + observation_noise = proto.Field(proto.ENUM, number=6, enum=ObservationNoise,) + + measurement_selection_type = proto.Field( + proto.ENUM, number=7, enum=MeasurementSelectionType, + ) + + +class Measurement(proto.Message): + r"""A message representing a Measurement of a Trial. A + Measurement contains the Metrics got by executing a Trial using + suggested hyperparameter values. + + Attributes: + step_count (int): + Output only. The number of steps the machine + learning model has been trained for. Must be + non-negative. + metrics (Sequence[google.cloud.aiplatform_v1.types.Measurement.Metric]): + Output only. A list of metrics got by + evaluating the objective functions using + suggested Parameter values. + """ + + class Metric(proto.Message): + r"""A message representing a metric in the measurement. + + Attributes: + metric_id (str): + Output only. The ID of the Metric. The Metric should be + defined in [StudySpec's + Metrics][google.cloud.aiplatform.v1.StudySpec.metrics]. + value (float): + Output only. The value for this metric. 
+ """ + + metric_id = proto.Field(proto.STRING, number=1) + + value = proto.Field(proto.DOUBLE, number=2) + + step_count = proto.Field(proto.INT64, number=2) + + metrics = proto.RepeatedField(proto.MESSAGE, number=3, message=Metric,) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/training_pipeline.py b/google/cloud/aiplatform_v1/types/training_pipeline.py new file mode 100644 index 0000000000..9a41f231a5 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/training_pipeline.py @@ -0,0 +1,467 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +from google.cloud.aiplatform_v1.types import encryption_spec as gca_encryption_spec +from google.cloud.aiplatform_v1.types import io +from google.cloud.aiplatform_v1.types import model +from google.cloud.aiplatform_v1.types import pipeline_state +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", + manifest={ + "TrainingPipeline", + "InputDataConfig", + "FractionSplit", + "FilterSplit", + "PredefinedSplit", + "TimestampSplit", + }, +) + + +class TrainingPipeline(proto.Message): + r"""The TrainingPipeline orchestrates tasks associated with training a + Model. It always executes the training task, and optionally may also + export data from AI Platform's Dataset which becomes the training + input, ``upload`` + the Model to AI Platform, and evaluate the Model. + + Attributes: + name (str): + Output only. Resource name of the + TrainingPipeline. + display_name (str): + Required. The user-defined name of this + TrainingPipeline. + input_data_config (google.cloud.aiplatform_v1.types.InputDataConfig): + Specifies AI Platform owned input data that may be used for + training the Model. The TrainingPipeline's + ``training_task_definition`` + should make clear whether this config is used and if there + are any special requirements on how it should be filled. If + nothing about this config is mentioned in the + ``training_task_definition``, + then it should be assumed that the TrainingPipeline does not + depend on this configuration. + training_task_definition (str): + Required. A Google Cloud Storage path to the + YAML file that defines the training task which + is responsible for producing the model artifact, + and may also include additional auxiliary work. + The definition files that can be used here are + found in gs://google-cloud- + aiplatform/schema/trainingjob/definition/. Note: + The URI given on output will be immutable and + probably different, including the URI scheme, + than the one given on input. The output URI will + point to a location where the user only has a + read access. + training_task_inputs (google.protobuf.struct_pb2.Value): + Required. 
The training task's parameter(s), as specified in + the + ``training_task_definition``'s + ``inputs``. + training_task_metadata (google.protobuf.struct_pb2.Value): + Output only. The metadata information as specified in the + ``training_task_definition``'s + ``metadata``. This metadata is an auxiliary runtime and + final information about the training task. While the + pipeline is running this information is populated only at a + best effort basis. Only present if the pipeline's + ``training_task_definition`` + contains ``metadata`` object. + model_to_upload (google.cloud.aiplatform_v1.types.Model): + Describes the Model that may be uploaded (via + ``ModelService.UploadModel``) + by this TrainingPipeline. The TrainingPipeline's + ``training_task_definition`` + should make clear whether this Model description should be + populated, and if there are any special requirements + regarding how it should be filled. If nothing is mentioned + in the + ``training_task_definition``, + then it should be assumed that this field should not be + filled and the training task either uploads the Model + without a need of this information, or that training task + does not support uploading a Model as part of the pipeline. + When the Pipeline's state becomes + ``PIPELINE_STATE_SUCCEEDED`` and the trained Model had been + uploaded into AI Platform, then the model_to_upload's + resource ``name`` is + populated. The Model is always uploaded into the Project and + Location in which this pipeline is. + state (google.cloud.aiplatform_v1.types.PipelineState): + Output only. The detailed state of the + pipeline. + error (google.rpc.status_pb2.Status): + Output only. Only populated when the pipeline's state is + ``PIPELINE_STATE_FAILED`` or ``PIPELINE_STATE_CANCELLED``. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the TrainingPipeline + was created. + start_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the TrainingPipeline for the first + time entered the ``PIPELINE_STATE_RUNNING`` state. + end_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the TrainingPipeline entered any of + the following states: ``PIPELINE_STATE_SUCCEEDED``, + ``PIPELINE_STATE_FAILED``, ``PIPELINE_STATE_CANCELLED``. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Time when the TrainingPipeline + was most recently updated. + labels (Sequence[google.cloud.aiplatform_v1.types.TrainingPipeline.LabelsEntry]): + The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): + Customer-managed encryption key spec for a TrainingPipeline. + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if + ``model_to_upload`` + is not set separately. 
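        Example (a minimal construction sketch; the display name, schema
        file, task inputs, and dataset ID below are placeholder assumptions,
        not a recommended configuration)::

            from google.protobuf import json_format, struct_pb2

            # training_task_inputs is an arbitrary struct whose shape is
            # dictated by the chosen training_task_definition schema.
            inputs = json_format.ParseDict(
                {"budgetMilliNodeHours": 8000}, struct_pb2.Value()
            )
            pipeline = aiplatform_v1.TrainingPipeline(
                display_name="my-pipeline",
                training_task_definition=(
                    "gs://google-cloud-aiplatform/schema/trainingjob/"
                    "definition/automl_image_classification_1.0.0.yaml"
                ),
                training_task_inputs=inputs,
                input_data_config=aiplatform_v1.InputDataConfig(
                    dataset_id="1234",
                    fraction_split=aiplatform_v1.FractionSplit(
                        training_fraction=0.8,
                        validation_fraction=0.1,
                        test_fraction=0.1,
                    ),
                ),
            )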
+ """ + + name = proto.Field(proto.STRING, number=1) + + display_name = proto.Field(proto.STRING, number=2) + + input_data_config = proto.Field(proto.MESSAGE, number=3, message="InputDataConfig",) + + training_task_definition = proto.Field(proto.STRING, number=4) + + training_task_inputs = proto.Field(proto.MESSAGE, number=5, message=struct.Value,) + + training_task_metadata = proto.Field(proto.MESSAGE, number=6, message=struct.Value,) + + model_to_upload = proto.Field(proto.MESSAGE, number=7, message=model.Model,) + + state = proto.Field(proto.ENUM, number=9, enum=pipeline_state.PipelineState,) + + error = proto.Field(proto.MESSAGE, number=10, message=status.Status,) + + create_time = proto.Field(proto.MESSAGE, number=11, message=timestamp.Timestamp,) + + start_time = proto.Field(proto.MESSAGE, number=12, message=timestamp.Timestamp,) + + end_time = proto.Field(proto.MESSAGE, number=13, message=timestamp.Timestamp,) + + update_time = proto.Field(proto.MESSAGE, number=14, message=timestamp.Timestamp,) + + labels = proto.MapField(proto.STRING, proto.STRING, number=15) + + encryption_spec = proto.Field( + proto.MESSAGE, number=18, message=gca_encryption_spec.EncryptionSpec, + ) + + +class InputDataConfig(proto.Message): + r"""Specifies AI Platform owned input data to be used for + training, and possibly evaluating, the Model. + + Attributes: + fraction_split (google.cloud.aiplatform_v1.types.FractionSplit): + Split based on fractions defining the size of + each set. + filter_split (google.cloud.aiplatform_v1.types.FilterSplit): + Split based on the provided filters for each + set. + predefined_split (google.cloud.aiplatform_v1.types.PredefinedSplit): + Supported only for tabular Datasets. + Split based on a predefined key. + timestamp_split (google.cloud.aiplatform_v1.types.TimestampSplit): + Supported only for tabular Datasets. + Split based on the timestamp of the input data + pieces. + gcs_destination (google.cloud.aiplatform_v1.types.GcsDestination): + The Cloud Storage location where the training data is to be + written to. In the given directory a new directory is + created with name: + ``dataset---`` + where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 + format. All training input data is written into that + directory. + + The AI Platform environment variables representing Cloud + Storage data URIs are represented in the Cloud Storage + wildcard format to support sharded data. e.g.: + "gs://.../training-*.jsonl" + + - AIP_DATA_FORMAT = "jsonl" for non-tabular data, "csv" for + tabular data + - AIP_TRAINING_DATA_URI = + + "gcs_destination/dataset---/training-*.${AIP_DATA_FORMAT}" + + - AIP_VALIDATION_DATA_URI = + + "gcs_destination/dataset---/validation-*.${AIP_DATA_FORMAT}" + + - AIP_TEST_DATA_URI = + + "gcs_destination/dataset---/test-*.${AIP_DATA_FORMAT}". + bigquery_destination (google.cloud.aiplatform_v1.types.BigQueryDestination): + Only applicable to custom training with tabular Dataset with + BigQuery source. + + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data is written into that dataset. In the + dataset three tables are created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". 
+ - AIP_TRAINING_DATA_URI = + + "bigquery_destination.dataset_<dataset-id>_<annotation-type>_<time>.training" + + - AIP_VALIDATION_DATA_URI = + + "bigquery_destination.dataset_<dataset-id>_<annotation-type>_<time>.validation" + + - AIP_TEST_DATA_URI = + "bigquery_destination.dataset_<dataset-id>_<annotation-type>_<time>.test". + dataset_id (str): + Required. The ID of the Dataset in the same Project and + Location whose data will be used to train the Model. The + Dataset must use a schema compatible with the Model being + trained, and what is compatible should be described in the + used TrainingPipeline's [training_task_definition] + [google.cloud.aiplatform.v1.TrainingPipeline.training_task_definition]. + For tabular Datasets, all their data is exported to + training, to pick and choose from. + annotations_filter (str): + Applicable only to Datasets that have DataItems and + Annotations. + + A filter on Annotations of the Dataset. Only Annotations + that both match this filter and belong to DataItems not + ignored by the split method are used in respectively + training, validation or test role, depending on the role of + the DataItem they are on (for auto-assigned DataItems the + role is decided by AI Platform). A filter with the same + syntax as the one used in + ``ListAnnotations`` + may be used, but note here it filters across all Annotations + of the Dataset, and not just within a single DataItem. + annotation_schema_uri (str): + Applicable only to custom training with Datasets that have + DataItems and Annotations. + + Cloud Storage URI that points to a YAML file describing the + annotation schema. The schema is defined as an OpenAPI 3.0.2 + `Schema + Object `__. The + schema files that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/; + note that the chosen schema must be consistent with + ``metadata`` + of the Dataset specified by + ``dataset_id``. + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + """ + + fraction_split = proto.Field( + proto.MESSAGE, number=2, oneof="split", message="FractionSplit", + ) + + filter_split = proto.Field( + proto.MESSAGE, number=3, oneof="split", message="FilterSplit", + ) + + predefined_split = proto.Field( + proto.MESSAGE, number=4, oneof="split", message="PredefinedSplit", + ) + + timestamp_split = proto.Field( + proto.MESSAGE, number=5, oneof="split", message="TimestampSplit", + ) + + gcs_destination = proto.Field( + proto.MESSAGE, number=8, oneof="destination", message=io.GcsDestination, + ) + + bigquery_destination = proto.Field( + proto.MESSAGE, number=10, oneof="destination", message=io.BigQueryDestination, + ) + + dataset_id = proto.Field(proto.STRING, number=1) + + annotations_filter = proto.Field(proto.STRING, number=6) + + annotation_schema_uri = proto.Field(proto.STRING, number=9) + + +class FractionSplit(proto.Message): + r"""Assigns the input data to training, validation, and test sets as per + the given fractions. Any of ``training_fraction``, + ``validation_fraction`` and ``test_fraction`` may optionally be + provided; together they must sum to at most 1. If the provided + fractions sum to less + than 1, the remainder is assigned to sets as decided by AI Platform.
+ If none of the fractions are set, by default roughly 80% of data is + used for training, 10% for validation, and 10% for test. + + Attributes: + training_fraction (float): + The fraction of the input data that is to be + used to train the Model. + validation_fraction (float): + The fraction of the input data that is to be + used to validate the Model. + test_fraction (float): + The fraction of the input data that is to be + used to evaluate the Model. + """ + + training_fraction = proto.Field(proto.DOUBLE, number=1) + + validation_fraction = proto.Field(proto.DOUBLE, number=2) + + test_fraction = proto.Field(proto.DOUBLE, number=3) + + +class FilterSplit(proto.Message): + r"""Assigns input data to training, validation, and test sets + based on the given filters, data pieces not matched by any + filter are ignored. Currently only supported for Datasets + containing DataItems. + If any of the filters in this message are to match nothing, then + they can be set as '-' (the minus sign). + + Supported only for unstructured Datasets. + + Attributes: + training_filter (str): + Required. A filter on DataItems of the Dataset. DataItems + that match this filter are used to train the Model. A filter + with same syntax as the one used in + ``DatasetService.ListDataItems`` + may be used. If a single DataItem is matched by more than + one of the FilterSplit filters, then it is assigned to the + first set that applies to it in the training, validation, + test order. + validation_filter (str): + Required. A filter on DataItems of the Dataset. DataItems + that match this filter are used to validate the Model. A + filter with same syntax as the one used in + ``DatasetService.ListDataItems`` + may be used. If a single DataItem is matched by more than + one of the FilterSplit filters, then it is assigned to the + first set that applies to it in the training, validation, + test order. + test_filter (str): + Required. A filter on DataItems of the Dataset. DataItems + that match this filter are used to test the Model. A filter + with same syntax as the one used in + ``DatasetService.ListDataItems`` + may be used. If a single DataItem is matched by more than + one of the FilterSplit filters, then it is assigned to the + first set that applies to it in the training, validation, + test order. + """ + + training_filter = proto.Field(proto.STRING, number=1) + + validation_filter = proto.Field(proto.STRING, number=2) + + test_filter = proto.Field(proto.STRING, number=3) + + +class PredefinedSplit(proto.Message): + r"""Assigns input data to training, validation, and test sets + based on the value of a provided key. + + Supported only for tabular Datasets. + + Attributes: + key (str): + Required. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + """ + + key = proto.Field(proto.STRING, number=1) + + +class TimestampSplit(proto.Message): + r"""Assigns input data to training, validation, and test sets + based on a provided timestamps. The youngest data pieces are + assigned to training set, next to validation set, and the oldest + to the test set. + Supported only for tabular Datasets. + + Attributes: + training_fraction (float): + The fraction of the input data that is to be + used to train the Model. 
+ validation_fraction (float): + The fraction of the input data that is to be + used to validate the Model. + test_fraction (float): + The fraction of the input data that is to be + used to evaluate the Model. + key (str): + Required. The key is a name of one of the Dataset's data + columns. The values of the key (the values in the column) + must be in RFC 3339 ``date-time`` format, where + ``time-offset`` = ``"Z"`` (e.g. 1985-04-12T23:20:50.52Z). If + for a piece of data the key is not present or has an invalid + value, that piece is ignored by the pipeline. + """ + + training_fraction = proto.Field(proto.DOUBLE, number=1) + + validation_fraction = proto.Field(proto.DOUBLE, number=2) + + test_fraction = proto.Field(proto.DOUBLE, number=3) + + key = proto.Field(proto.STRING, number=4) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/user_action_reference.py b/google/cloud/aiplatform_v1/types/user_action_reference.py new file mode 100644 index 0000000000..da59ac6ac6 --- /dev/null +++ b/google/cloud/aiplatform_v1/types/user_action_reference.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1", manifest={"UserActionReference",}, +) + + +class UserActionReference(proto.Message): + r"""References an API call. It contains more information about + long running operation and Jobs that are triggered by the API + call. + + Attributes: + operation (str): + For API calls that return a long running + operation. Resource name of the long running + operation. Format: + 'projects/{project}/locations/{location}/operations/{operation}' + data_labeling_job (str): + For API calls that start a LabelingJob. Resource name of the + LabelingJob. Format: + + 'projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}' + method (str): + The method name of the API call. For example, + "/google.cloud.aiplatform.v1alpha1.DatasetService.CreateDataset". 
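For the training_pipeline.py types above: the four split messages are mutually exclusive members of ``InputDataConfig``'s ``split`` oneof, so populating one clears the others. A minimal sketch, assuming a hypothetical Dataset ID::

    from google.cloud import aiplatform_v1

    # fraction_split, filter_split, predefined_split and timestamp_split
    # share the `split` oneof; only one may be set at a time.
    input_config = aiplatform_v1.InputDataConfig(
        dataset_id="1234567890",  # hypothetical Dataset ID
        fraction_split=aiplatform_v1.FractionSplit(
            training_fraction=0.8,
            validation_fraction=0.1,
            test_fraction=0.1,
        ),
    )

The same config accepts ``filter_split``, ``predefined_split`` or ``timestamp_split`` in place of ``fraction_split``, but never more than one.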
+ """ + + operation = proto.Field(proto.STRING, number=1, oneof="reference") + + data_labeling_job = proto.Field(proto.STRING, number=2, oneof="reference") + + method = proto.Field(proto.STRING, number=3) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index 5f466b2e9b..b76824eac3 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -61,6 +61,7 @@ from .types.dataset_service import ListDatasetsResponse from .types.dataset_service import UpdateDatasetRequest from .types.deployed_model_ref import DeployedModelRef +from .types.encryption_spec import EncryptionSpec from .types.endpoint import DeployedModel from .types.endpoint import Endpoint from .types.endpoint_service import CreateEndpointOperationMetadata @@ -79,8 +80,10 @@ from .types.env_var import EnvVar from .types.explanation import Attribution from .types.explanation import Explanation +from .types.explanation import ExplanationMetadataOverride from .types.explanation import ExplanationParameters from .types.explanation import ExplanationSpec +from .types.explanation import ExplanationSpecOverride from .types.explanation import FeatureNoiseSigma from .types.explanation import IntegratedGradientsAttribution from .types.explanation import ModelExplanation @@ -247,6 +250,7 @@ "DeployedModel", "DeployedModelRef", "DiskSpec", + "EncryptionSpec", "Endpoint", "EndpointServiceClient", "EnvVar", @@ -254,8 +258,10 @@ "ExplainResponse", "Explanation", "ExplanationMetadata", + "ExplanationMetadataOverride", "ExplanationParameters", "ExplanationSpec", + "ExplanationSpecOverride", "ExportDataConfig", "ExportDataOperationMetadata", "ExportDataRequest", diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py index 1927709f30..2915b4888d 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/async_client.py @@ -37,6 +37,7 @@ from google.cloud.aiplatform_v1beta1.types import dataset from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset from google.cloud.aiplatform_v1beta1.types import dataset_service +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import field_mask_pb2 as field_mask # type: ignore @@ -96,6 +97,7 @@ class DatasetServiceAsyncClient: DatasetServiceClient.parse_common_location_path ) + from_service_account_info = DatasetServiceClient.from_service_account_info from_service_account_file = DatasetServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -173,17 +175,18 @@ async def create_dataset( r"""Creates a Dataset. Args: - request (:class:`~.dataset_service.CreateDatasetRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateDatasetRequest`): The request object. Request message for ``DatasetService.CreateDataset``. parent (:class:`str`): Required. The resource name of the Location to create the Dataset in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. 
- dataset (:class:`~.gca_dataset.Dataset`): + dataset (:class:`google.cloud.aiplatform_v1beta1.types.Dataset`): Required. The Dataset to create. This corresponds to the ``dataset`` field on the ``request`` instance; if ``request`` is provided, this @@ -196,12 +199,12 @@ async def create_dataset( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. The result type for the operation will be - :class:`~.gca_dataset.Dataset`: A collection of - DataItems and Annotations on them. + :class:`google.cloud.aiplatform_v1beta1.types.Dataset` A + collection of DataItems and Annotations on them. """ # Create or coerce a protobuf request object. @@ -264,12 +267,13 @@ async def get_dataset( r"""Gets a Dataset. Args: - request (:class:`~.dataset_service.GetDatasetRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.GetDatasetRequest`): The request object. Request message for ``DatasetService.GetDataset``. name (:class:`str`): Required. The name of the Dataset resource. + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -281,7 +285,7 @@ async def get_dataset( sent along with the request as metadata. Returns: - ~.dataset.Dataset: + google.cloud.aiplatform_v1beta1.types.Dataset: A collection of DataItems and Annotations on them. @@ -337,25 +341,26 @@ async def update_dataset( r"""Updates a Dataset. Args: - request (:class:`~.dataset_service.UpdateDatasetRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateDatasetRequest`): The request object. Request message for ``DatasetService.UpdateDataset``. - dataset (:class:`~.gca_dataset.Dataset`): + dataset (:class:`google.cloud.aiplatform_v1beta1.types.Dataset`): Required. The Dataset which replaces the resource on the server. + This corresponds to the ``dataset`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - update_mask (:class:`~.field_mask.FieldMask`): + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): Required. The update mask applies to the resource. For the ``FieldMask`` definition, see - - [FieldMask](https://tinyurl.com/dev-google-protobuf#google.protobuf.FieldMask). + `FieldMask `__. Updatable fields: - ``display_name`` - ``description`` - ``labels`` + This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -367,7 +372,7 @@ async def update_dataset( sent along with the request as metadata. Returns: - ~.gca_dataset.Dataset: + google.cloud.aiplatform_v1beta1.types.Dataset: A collection of DataItems and Annotations on them. @@ -426,12 +431,13 @@ async def list_datasets( r"""Lists Datasets in a Location. Args: - request (:class:`~.dataset_service.ListDatasetsRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListDatasetsRequest`): The request object. Request message for ``DatasetService.ListDatasets``. parent (:class:`str`): Required. The name of the Dataset's parent resource. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -443,7 +449,7 @@ async def list_datasets( sent along with the request as metadata. 
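As a sketch of the masked-update flow documented above: the dataset name is hypothetical, only the paths named in the mask are applied, and ``display_name``, ``description`` and ``labels`` are the updatable fields::

    from google.cloud import aiplatform_v1beta1
    from google.protobuf import field_mask_pb2

    client = aiplatform_v1beta1.DatasetServiceClient()

    dataset = aiplatform_v1beta1.Dataset(
        name="projects/my-project/locations/us-central1/datasets/123",
        display_name="renamed-dataset",
    )
    updated = client.update_dataset(
        dataset=dataset,
        # Only display_name is touched; unlisted fields keep their
        # server-side values.
        update_mask=field_mask_pb2.FieldMask(paths=["display_name"]),
    )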
Returns: - ~.pagers.ListDatasetsAsyncPager: + google.cloud.aiplatform_v1beta1.services.dataset_service.pagers.ListDatasetsAsyncPager: Response message for ``DatasetService.ListDatasets``. @@ -507,13 +513,14 @@ async def delete_dataset( r"""Deletes a Dataset. Args: - request (:class:`~.dataset_service.DeleteDatasetRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteDatasetRequest`): The request object. Request message for ``DatasetService.DeleteDataset``. name (:class:`str`): Required. The resource name of the Dataset to delete. Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -525,24 +532,22 @@ async def delete_dataset( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -604,19 +609,21 @@ async def import_data( r"""Imports data into a Dataset. Args: - request (:class:`~.dataset_service.ImportDataRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ImportDataRequest`): The request object. Request message for ``DatasetService.ImportData``. name (:class:`str`): Required. The name of the Dataset resource. Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - import_configs (:class:`Sequence[~.dataset.ImportDataConfig]`): + import_configs (:class:`Sequence[google.cloud.aiplatform_v1beta1.types.ImportDataConfig]`): Required. The desired input locations. The contents of all input locations will be imported in one batch. + This corresponds to the ``import_configs`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -628,11 +635,11 @@ async def import_data( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. The result type for the operation will be - :class:`~.dataset_service.ImportDataResponse`: + :class:`google.cloud.aiplatform_v1beta1.types.ImportDataResponse` Response message for ``DatasetService.ImportData``. @@ -699,18 +706,20 @@ async def export_data( r"""Exports data from a Dataset. Args: - request (:class:`~.dataset_service.ExportDataRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ExportDataRequest`): The request object. 
Request message for ``DatasetService.ExportData``. name (:class:`str`): Required. The name of the Dataset resource. Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - export_config (:class:`~.dataset.ExportDataConfig`): + export_config (:class:`google.cloud.aiplatform_v1beta1.types.ExportDataConfig`): Required. The desired output location. + This corresponds to the ``export_config`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -722,11 +731,11 @@ async def export_data( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. The result type for the operation will be - :class:`~.dataset_service.ExportDataResponse`: + :class:`google.cloud.aiplatform_v1beta1.types.ExportDataResponse` Response message for ``DatasetService.ExportData``. @@ -791,13 +800,14 @@ async def list_data_items( r"""Lists DataItems in a Dataset. Args: - request (:class:`~.dataset_service.ListDataItemsRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListDataItemsRequest`): The request object. Request message for ``DatasetService.ListDataItems``. parent (:class:`str`): Required. The resource name of the Dataset to list DataItems from. Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -809,7 +819,7 @@ async def list_data_items( sent along with the request as metadata. Returns: - ~.pagers.ListDataItemsAsyncPager: + google.cloud.aiplatform_v1beta1.services.dataset_service.pagers.ListDataItemsAsyncPager: Response message for ``DatasetService.ListDataItems``. @@ -873,7 +883,7 @@ async def get_annotation_spec( r"""Gets an AnnotationSpec. Args: - request (:class:`~.dataset_service.GetAnnotationSpecRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.GetAnnotationSpecRequest`): The request object. Request message for ``DatasetService.GetAnnotationSpec``. name (:class:`str`): @@ -881,6 +891,7 @@ async def get_annotation_spec( Format: ``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -892,7 +903,7 @@ async def get_annotation_spec( sent along with the request as metadata. Returns: - ~.annotation_spec.AnnotationSpec: + google.cloud.aiplatform_v1beta1.types.AnnotationSpec: Identifies a concept with which DataItems may be annotated with. @@ -947,7 +958,7 @@ async def list_annotations( r"""Lists Annotations belongs to a dataitem Args: - request (:class:`~.dataset_service.ListAnnotationsRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListAnnotationsRequest`): The request object. Request message for ``DatasetService.ListAnnotations``. parent (:class:`str`): @@ -955,6 +966,7 @@ async def list_annotations( Annotations from. Format: ``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -966,7 +978,7 @@ async def list_annotations( sent along with the request as metadata. 
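A hedged sketch of the ``import_data`` call described above, assuming a hypothetical bucket and an illustrative single-label image-classification import schema::

    from google.cloud import aiplatform_v1beta1

    client = aiplatform_v1beta1.DatasetServiceClient()

    # All configs in the list are ingested as one batch.
    import_config = aiplatform_v1beta1.ImportDataConfig(
        gcs_source=aiplatform_v1beta1.GcsSource(
            uris=["gs://my-bucket/data.jsonl"]  # hypothetical import file
        ),
        import_schema_uri=(
            "gs://google-cloud-aiplatform/schema/dataset/ioformat/"
            "image_classification_single_label_io_format_1.0.0.yaml"
        ),
    )
    operation = client.import_data(
        name="projects/my-project/locations/us-central1/datasets/123",
        import_configs=[import_config],
    )
    operation.result()  # blocks until the batch import finishes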
Returns: - ~.pagers.ListAnnotationsAsyncPager: + google.cloud.aiplatform_v1beta1.services.dataset_service.pagers.ListAnnotationsAsyncPager: Response message for ``DatasetService.ListAnnotations``. diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py index 1e63153291..187b9c8f0c 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/client.py @@ -41,6 +41,7 @@ from google.cloud.aiplatform_v1beta1.types import dataset from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset from google.cloud.aiplatform_v1beta1.types import dataset_service +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.protobuf import empty_pb2 as empty # type: ignore from google.protobuf import field_mask_pb2 as field_mask # type: ignore @@ -122,6 +123,22 @@ def _get_default_mtls_endpoint(api_endpoint): DEFAULT_ENDPOINT ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DatasetServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -134,7 +151,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): kwargs: Additional arguments to pass to the constructor. Returns: - {@api.name}: The constructed client. + DatasetServiceClient: The constructed client. """ credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials @@ -303,10 +320,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.DatasetServiceTransport]): The + transport (Union[str, DatasetServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the + client_options (google.api_core.client_options.ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. 
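Two of the additions in this file compose naturally: ``from_service_account_info``, which takes the parsed key mapping rather than a filename, and the ``client_cert_source`` client option. A sketch with hypothetical key and certificate paths; note that ``GOOGLE_API_USE_CLIENT_CERTIFICATE`` must be set to ``true`` for the callback to be consulted::

    import json

    from google.api_core.client_options import ClientOptions
    from google.cloud import aiplatform_v1beta1

    with open("service-account.json") as f:  # hypothetical key file
        info = json.load(f)

    def client_cert_source():
        # Must return PEM-encoded (certificate bytes, private key bytes).
        with open("client-cert.pem", "rb") as cert, open("client-key.pem", "rb") as key:
            return cert.read(), key.read()

    client = aiplatform_v1beta1.DatasetServiceClient.from_service_account_info(
        info,
        client_options=ClientOptions(client_cert_source=client_cert_source),
    )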
GOOGLE_API_USE_MTLS_ENDPOINT @@ -342,21 +359,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -399,7 +412,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -417,17 +430,18 @@ def create_dataset( r"""Creates a Dataset. Args: - request (:class:`~.dataset_service.CreateDatasetRequest`): + request (google.cloud.aiplatform_v1beta1.types.CreateDatasetRequest): The request object. Request message for ``DatasetService.CreateDataset``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location to create the Dataset in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - dataset (:class:`~.gca_dataset.Dataset`): + dataset (google.cloud.aiplatform_v1beta1.types.Dataset): Required. The Dataset to create. This corresponds to the ``dataset`` field on the ``request`` instance; if ``request`` is provided, this @@ -440,12 +454,12 @@ def create_dataset( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. The result type for the operation will be - :class:`~.gca_dataset.Dataset`: A collection of - DataItems and Annotations on them. + :class:`google.cloud.aiplatform_v1beta1.types.Dataset` A + collection of DataItems and Annotations on them. """ # Create or coerce a protobuf request object. @@ -509,12 +523,13 @@ def get_dataset( r"""Gets a Dataset. Args: - request (:class:`~.dataset_service.GetDatasetRequest`): + request (google.cloud.aiplatform_v1beta1.types.GetDatasetRequest): The request object. Request message for ``DatasetService.GetDataset``. - name (:class:`str`): + name (str): Required. The name of the Dataset resource. + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -526,7 +541,7 @@ def get_dataset( sent along with the request as metadata. Returns: - ~.dataset.Dataset: + google.cloud.aiplatform_v1beta1.types.Dataset: A collection of DataItems and Annotations on them. @@ -583,25 +598,26 @@ def update_dataset( r"""Updates a Dataset. Args: - request (:class:`~.dataset_service.UpdateDatasetRequest`): + request (google.cloud.aiplatform_v1beta1.types.UpdateDatasetRequest): The request object. Request message for ``DatasetService.UpdateDataset``. - dataset (:class:`~.gca_dataset.Dataset`): + dataset (google.cloud.aiplatform_v1beta1.types.Dataset): Required. 
The Dataset which replaces the resource on the server. + This corresponds to the ``dataset`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - update_mask (:class:`~.field_mask.FieldMask`): + update_mask (google.protobuf.field_mask_pb2.FieldMask): Required. The update mask applies to the resource. For the ``FieldMask`` definition, see - - [FieldMask](https://tinyurl.com/dev-google-protobuf#google.protobuf.FieldMask). + `FieldMask `__. Updatable fields: - ``display_name`` - ``description`` - ``labels`` + This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -613,7 +629,7 @@ def update_dataset( sent along with the request as metadata. Returns: - ~.gca_dataset.Dataset: + google.cloud.aiplatform_v1beta1.types.Dataset: A collection of DataItems and Annotations on them. @@ -673,12 +689,13 @@ def list_datasets( r"""Lists Datasets in a Location. Args: - request (:class:`~.dataset_service.ListDatasetsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListDatasetsRequest): The request object. Request message for ``DatasetService.ListDatasets``. - parent (:class:`str`): + parent (str): Required. The name of the Dataset's parent resource. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -690,7 +707,7 @@ def list_datasets( sent along with the request as metadata. Returns: - ~.pagers.ListDatasetsPager: + google.cloud.aiplatform_v1beta1.services.dataset_service.pagers.ListDatasetsPager: Response message for ``DatasetService.ListDatasets``. @@ -755,13 +772,14 @@ def delete_dataset( r"""Deletes a Dataset. Args: - request (:class:`~.dataset_service.DeleteDatasetRequest`): + request (google.cloud.aiplatform_v1beta1.types.DeleteDatasetRequest): The request object. Request message for ``DatasetService.DeleteDataset``. - name (:class:`str`): + name (str): Required. The resource name of the Dataset to delete. Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -773,24 +791,22 @@ def delete_dataset( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -853,19 +869,21 @@ def import_data( r"""Imports data into a Dataset. 
Args: - request (:class:`~.dataset_service.ImportDataRequest`): + request (google.cloud.aiplatform_v1beta1.types.ImportDataRequest): The request object. Request message for ``DatasetService.ImportData``. - name (:class:`str`): + name (str): Required. The name of the Dataset resource. Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - import_configs (:class:`Sequence[~.dataset.ImportDataConfig]`): + import_configs (Sequence[google.cloud.aiplatform_v1beta1.types.ImportDataConfig]): Required. The desired input locations. The contents of all input locations will be imported in one batch. + This corresponds to the ``import_configs`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -877,11 +895,11 @@ def import_data( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. The result type for the operation will be - :class:`~.dataset_service.ImportDataResponse`: + :class:`google.cloud.aiplatform_v1beta1.types.ImportDataResponse` Response message for ``DatasetService.ImportData``. @@ -949,18 +967,20 @@ def export_data( r"""Exports data from a Dataset. Args: - request (:class:`~.dataset_service.ExportDataRequest`): + request (google.cloud.aiplatform_v1beta1.types.ExportDataRequest): The request object. Request message for ``DatasetService.ExportData``. - name (:class:`str`): + name (str): Required. The name of the Dataset resource. Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - export_config (:class:`~.dataset.ExportDataConfig`): + export_config (google.cloud.aiplatform_v1beta1.types.ExportDataConfig): Required. The desired output location. + This corresponds to the ``export_config`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -972,11 +992,11 @@ def export_data( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. The result type for the operation will be - :class:`~.dataset_service.ExportDataResponse`: + :class:`google.cloud.aiplatform_v1beta1.types.ExportDataResponse` Response message for ``DatasetService.ExportData``. @@ -1042,13 +1062,14 @@ def list_data_items( r"""Lists DataItems in a Dataset. Args: - request (:class:`~.dataset_service.ListDataItemsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListDataItemsRequest): The request object. Request message for ``DatasetService.ListDataItems``. - parent (:class:`str`): + parent (str): Required. The resource name of the Dataset to list DataItems from. Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1060,7 +1081,7 @@ def list_data_items( sent along with the request as metadata. Returns: - ~.pagers.ListDataItemsPager: + google.cloud.aiplatform_v1beta1.services.dataset_service.pagers.ListDataItemsPager: Response message for ``DatasetService.ListDataItems``. @@ -1125,14 +1146,15 @@ def get_annotation_spec( r"""Gets an AnnotationSpec. 
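The mirror image of the import path, sketched with a hypothetical destination prefix::

    from google.cloud import aiplatform_v1beta1

    client = aiplatform_v1beta1.DatasetServiceClient()

    operation = client.export_data(
        name="projects/my-project/locations/us-central1/datasets/123",
        export_config=aiplatform_v1beta1.ExportDataConfig(
            gcs_destination=aiplatform_v1beta1.GcsDestination(
                output_uri_prefix="gs://my-bucket/exports/"
            )
        ),
    )
    # The operation result lists the files that were written.
    print(operation.result().exported_files)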
Args: - request (:class:`~.dataset_service.GetAnnotationSpecRequest`): + request (google.cloud.aiplatform_v1beta1.types.GetAnnotationSpecRequest): The request object. Request message for ``DatasetService.GetAnnotationSpec``. - name (:class:`str`): + name (str): Required. The name of the AnnotationSpec resource. Format: ``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1144,7 +1166,7 @@ def get_annotation_spec( sent along with the request as metadata. Returns: - ~.annotation_spec.AnnotationSpec: + google.cloud.aiplatform_v1beta1.types.AnnotationSpec: Identifies a concept with which DataItems may be annotated with. @@ -1200,14 +1222,15 @@ def list_annotations( r"""Lists Annotations belongs to a dataitem Args: - request (:class:`~.dataset_service.ListAnnotationsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListAnnotationsRequest): The request object. Request message for ``DatasetService.ListAnnotations``. - parent (:class:`str`): + parent (str): Required. The resource name of the DataItem to list Annotations from. Format: ``projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1219,7 +1242,7 @@ def list_annotations( sent along with the request as metadata. Returns: - ~.pagers.ListAnnotationsPager: + google.cloud.aiplatform_v1beta1.services.dataset_service.pagers.ListAnnotationsPager: Response message for ``DatasetService.ListAnnotations``. diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py index 43c3156caf..4c5d248571 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/pagers.py @@ -27,7 +27,7 @@ class ListDatasetsPager: """A pager for iterating through ``list_datasets`` requests. This class thinly wraps an initial - :class:`~.dataset_service.ListDatasetsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListDatasetsResponse` object, and provides an ``__iter__`` method to iterate through its ``datasets`` field. @@ -36,7 +36,7 @@ class ListDatasetsPager: through the ``datasets`` field on the corresponding responses. - All the usual :class:`~.dataset_service.ListDatasetsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListDatasetsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -54,9 +54,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.dataset_service.ListDatasetsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListDatasetsRequest): The initial request object. - response (:class:`~.dataset_service.ListDatasetsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListDatasetsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -89,7 +89,7 @@ class ListDatasetsAsyncPager: """A pager for iterating through ``list_datasets`` requests. 
This class thinly wraps an initial - :class:`~.dataset_service.ListDatasetsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListDatasetsResponse` object, and provides an ``__aiter__`` method to iterate through its ``datasets`` field. @@ -98,7 +98,7 @@ class ListDatasetsAsyncPager: through the ``datasets`` field on the corresponding responses. - All the usual :class:`~.dataset_service.ListDatasetsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListDatasetsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -116,9 +116,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.dataset_service.ListDatasetsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListDatasetsRequest): The initial request object. - response (:class:`~.dataset_service.ListDatasetsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListDatasetsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -155,7 +155,7 @@ class ListDataItemsPager: """A pager for iterating through ``list_data_items`` requests. This class thinly wraps an initial - :class:`~.dataset_service.ListDataItemsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListDataItemsResponse` object, and provides an ``__iter__`` method to iterate through its ``data_items`` field. @@ -164,7 +164,7 @@ class ListDataItemsPager: through the ``data_items`` field on the corresponding responses. - All the usual :class:`~.dataset_service.ListDataItemsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListDataItemsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -182,9 +182,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.dataset_service.ListDataItemsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListDataItemsRequest): The initial request object. - response (:class:`~.dataset_service.ListDataItemsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListDataItemsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -217,7 +217,7 @@ class ListDataItemsAsyncPager: """A pager for iterating through ``list_data_items`` requests. This class thinly wraps an initial - :class:`~.dataset_service.ListDataItemsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListDataItemsResponse` object, and provides an ``__aiter__`` method to iterate through its ``data_items`` field. @@ -226,7 +226,7 @@ class ListDataItemsAsyncPager: through the ``data_items`` field on the corresponding responses. - All the usual :class:`~.dataset_service.ListDataItemsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListDataItemsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -244,9 +244,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. 
- request (:class:`~.dataset_service.ListDataItemsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListDataItemsRequest): The initial request object. - response (:class:`~.dataset_service.ListDataItemsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListDataItemsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -283,7 +283,7 @@ class ListAnnotationsPager: """A pager for iterating through ``list_annotations`` requests. This class thinly wraps an initial - :class:`~.dataset_service.ListAnnotationsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListAnnotationsResponse` object, and provides an ``__iter__`` method to iterate through its ``annotations`` field. @@ -292,7 +292,7 @@ class ListAnnotationsPager: through the ``annotations`` field on the corresponding responses. - All the usual :class:`~.dataset_service.ListAnnotationsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListAnnotationsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -310,9 +310,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.dataset_service.ListAnnotationsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListAnnotationsRequest): The initial request object. - response (:class:`~.dataset_service.ListAnnotationsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListAnnotationsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -345,7 +345,7 @@ class ListAnnotationsAsyncPager: """A pager for iterating through ``list_annotations`` requests. This class thinly wraps an initial - :class:`~.dataset_service.ListAnnotationsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListAnnotationsResponse` object, and provides an ``__aiter__`` method to iterate through its ``annotations`` field. @@ -354,7 +354,7 @@ class ListAnnotationsAsyncPager: through the ``annotations`` field on the corresponding responses. - All the usual :class:`~.dataset_service.ListAnnotationsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListAnnotationsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -372,9 +372,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.dataset_service.ListAnnotationsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListAnnotationsRequest): The initial request object. - response (:class:`~.dataset_service.ListAnnotationsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListAnnotationsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
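In practice these pagers are rarely constructed by hand; iterating the object a list method returns is enough, and follow-up pages are fetched transparently. A short sketch with a hypothetical parent resource::

    from google.cloud import aiplatform_v1beta1

    client = aiplatform_v1beta1.DatasetServiceClient()

    # Each iteration yields a Dataset message; page boundaries are invisible.
    for dataset in client.list_datasets(
        parent="projects/my-project/locations/us-central1"
    ):
        print(dataset.name, dataset.display_name)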
diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py index f8496b801c..a4461d2ced 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/__init__.py @@ -28,7 +28,6 @@ _transport_registry["grpc"] = DatasetServiceGrpcTransport _transport_registry["grpc_asyncio"] = DatasetServiceGrpcAsyncIOTransport - __all__ = ( "DatasetServiceTransport", "DatasetServiceGrpcTransport", diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py index 583e9864cd..56f567959a 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/base.py @@ -74,10 +74,10 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py index 7120c2eb9a..b4fd90ee1f 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc.py @@ -60,6 +60,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -90,6 +91,10 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -106,6 +111,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. 
@@ -115,11 +125,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -149,6 +154,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -159,17 +168,28 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None # Run the base constructor. super().__init__( @@ -193,7 +213,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. + address (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -228,7 +248,8 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -239,13 +260,11 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( - self.grpc_channel - ) + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def create_dataset( diff --git a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py index aff766aa24..0c38b2ec38 100644 --- a/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/dataset_service/transports/grpc_asyncio.py @@ -104,6 +104,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -135,12 +136,16 @@ def __init__( ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -151,6 +156,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -160,11 +170,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -194,6 +199,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -204,14 +213,24 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) # Run the base constructor. @@ -225,6 +244,7 @@ def __init__( ) self._stubs = {} + self._operations_client = None @property def grpc_channel(self) -> aio.Channel: @@ -244,13 +264,13 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. 
- return self.__dict__["operations_client"] + return self._operations_client @property def create_dataset( diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py index 9c6af3bd16..43242b1148 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/async_client.py @@ -31,6 +31,7 @@ from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint from google.cloud.aiplatform_v1beta1.types import endpoint_service @@ -86,6 +87,7 @@ class EndpointServiceAsyncClient: EndpointServiceClient.parse_common_location_path ) + from_service_account_info = EndpointServiceClient.from_service_account_info from_service_account_file = EndpointServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -163,17 +165,18 @@ async def create_endpoint( r"""Creates an Endpoint. Args: - request (:class:`~.endpoint_service.CreateEndpointRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateEndpointRequest`): The request object. Request message for ``EndpointService.CreateEndpoint``. parent (:class:`str`): Required. The resource name of the Location to create the Endpoint in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - endpoint (:class:`~.gca_endpoint.Endpoint`): + endpoint (:class:`google.cloud.aiplatform_v1beta1.types.Endpoint`): Required. The Endpoint to create. This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this @@ -186,13 +189,11 @@ async def create_endpoint( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.gca_endpoint.Endpoint`: Models are deployed - into it, and afterwards Endpoint is called to obtain - predictions and explanations. + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.Endpoint` Models are deployed into it, and afterwards Endpoint is called to obtain + predictions and explanations. """ # Create or coerce a protobuf request object. @@ -255,12 +256,13 @@ async def get_endpoint( r"""Gets an Endpoint. Args: - request (:class:`~.endpoint_service.GetEndpointRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.GetEndpointRequest`): The request object. Request message for ``EndpointService.GetEndpoint`` name (:class:`str`): Required. The name of the Endpoint resource. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -272,7 +274,7 @@ async def get_endpoint( sent along with the request as metadata. 
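A sketch of the create-and-wait pattern behind these endpoint methods, with a hypothetical project and location::

    from google.cloud import aiplatform_v1beta1

    client = aiplatform_v1beta1.EndpointServiceClient()

    operation = client.create_endpoint(
        parent="projects/my-project/locations/us-central1",
        endpoint=aiplatform_v1beta1.Endpoint(display_name="my-endpoint"),
    )
    endpoint = operation.result()  # the finished Endpoint resource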
Returns: - ~.endpoint.Endpoint: + google.cloud.aiplatform_v1beta1.types.Endpoint: Models are deployed into it, and afterwards Endpoint is called to obtain predictions and explanations. @@ -328,13 +330,14 @@ async def list_endpoints( r"""Lists Endpoints in a Location. Args: - request (:class:`~.endpoint_service.ListEndpointsRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListEndpointsRequest`): The request object. Request message for ``EndpointService.ListEndpoints``. parent (:class:`str`): Required. The resource name of the Location from which to list the Endpoints. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -346,7 +349,7 @@ async def list_endpoints( sent along with the request as metadata. Returns: - ~.pagers.ListEndpointsAsyncPager: + google.cloud.aiplatform_v1beta1.services.endpoint_service.pagers.ListEndpointsAsyncPager: Response message for ``EndpointService.ListEndpoints``. @@ -411,18 +414,20 @@ async def update_endpoint( r"""Updates an Endpoint. Args: - request (:class:`~.endpoint_service.UpdateEndpointRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateEndpointRequest`): The request object. Request message for ``EndpointService.UpdateEndpoint``. - endpoint (:class:`~.gca_endpoint.Endpoint`): + endpoint (:class:`google.cloud.aiplatform_v1beta1.types.Endpoint`): Required. The Endpoint which replaces the resource on the server. + This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - update_mask (:class:`~.field_mask.FieldMask`): - Required. The update mask applies to - the resource. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The update mask applies to the resource. See + `FieldMask `__. + This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -434,7 +439,7 @@ async def update_endpoint( sent along with the request as metadata. Returns: - ~.gca_endpoint.Endpoint: + google.cloud.aiplatform_v1beta1.types.Endpoint: Models are deployed into it, and afterwards Endpoint is called to obtain predictions and explanations. @@ -494,13 +499,14 @@ async def delete_endpoint( r"""Deletes an Endpoint. Args: - request (:class:`~.endpoint_service.DeleteEndpointRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteEndpointRequest`): The request object. Request message for ``EndpointService.DeleteEndpoint``. name (:class:`str`): Required. The name of the Endpoint resource to be deleted. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -512,24 +518,22 @@ async def delete_endpoint( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. 
For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -595,27 +599,29 @@ async def deploy_model( DeployedModel within it. Args: - request (:class:`~.endpoint_service.DeployModelRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.DeployModelRequest`): The request object. Request message for ``EndpointService.DeployModel``. endpoint (:class:`str`): Required. The name of the Endpoint resource into which to deploy a Model. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - deployed_model (:class:`~.gca_endpoint.DeployedModel`): + deployed_model (:class:`google.cloud.aiplatform_v1beta1.types.DeployedModel`): Required. The DeployedModel to be created within the Endpoint. Note that ``Endpoint.traffic_split`` must be updated for the DeployedModel to start receiving traffic, either as part of this call, or via ``EndpointService.UpdateEndpoint``. + This corresponds to the ``deployed_model`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - traffic_split (:class:`Sequence[~.endpoint_service.DeployModelRequest.TrafficSplitEntry]`): + traffic_split (:class:`Sequence[google.cloud.aiplatform_v1beta1.types.DeployModelRequest.TrafficSplitEntry]`): A map from a DeployedModel's ID to the percentage of this Endpoint's traffic that should be forwarded to that DeployedModel. @@ -631,6 +637,7 @@ async def deploy_model( If this field is empty, then the Endpoint's ``traffic_split`` is not updated. + This corresponds to the ``traffic_split`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -642,11 +649,11 @@ async def deploy_model( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. The result type for the operation will be - :class:`~.endpoint_service.DeployModelResponse`: + :class:`google.cloud.aiplatform_v1beta1.types.DeployModelResponse` Response message for ``EndpointService.DeployModel``. @@ -720,23 +727,25 @@ async def undeploy_model( using. Args: - request (:class:`~.endpoint_service.UndeployModelRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.UndeployModelRequest`): The request object. Request message for ``EndpointService.UndeployModel``. endpoint (:class:`str`): Required. The name of the Endpoint resource from which to undeploy a Model. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this should not be set. deployed_model_id (:class:`str`): Required. The ID of the DeployedModel to be undeployed from the Endpoint. 
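Taken together, ``deploy_model`` plus ``traffic_split`` let a single call both create the DeployedModel and route traffic to it; per the request documentation, the key ``"0"`` refers to the model being deployed by the request itself. A minimal sketch with placeholder project, endpoint, and model names:

    from google.cloud.aiplatform_v1beta1.services.endpoint_service import (
        EndpointServiceAsyncClient,
    )
    from google.cloud.aiplatform_v1beta1.types import (
        DedicatedResources,
        DeployedModel,
        MachineSpec,
    )

    async def deploy(client: EndpointServiceAsyncClient) -> None:
        operation = await client.deploy_model(
            endpoint="projects/my-project/locations/us-central1/endpoints/123",
            deployed_model=DeployedModel(
                model="projects/my-project/locations/us-central1/models/456",
                dedicated_resources=DedicatedResources(
                    machine_spec=MachineSpec(machine_type="n1-standard-2"),
                    min_replica_count=1,
                ),
            ),
            # "0" = the model deployed by this request; it takes all traffic.
            traffic_split={"0": 100},
        )
        response = await operation.result()  # DeployModelResponse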
+ This corresponds to the ``deployed_model_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - traffic_split (:class:`Sequence[~.endpoint_service.UndeployModelRequest.TrafficSplitEntry]`): + traffic_split (:class:`Sequence[google.cloud.aiplatform_v1beta1.types.UndeployModelRequest.TrafficSplitEntry]`): If this field is provided, then the Endpoint's ``traffic_split`` will be overwritten with it. If last DeployedModel is @@ -746,6 +755,7 @@ async def undeploy_model( undeployed only if it doesn't have any traffic assigned to it when this method executes, or if this field unassigns any traffic to it. + This corresponds to the ``traffic_split`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -757,11 +767,11 @@ async def undeploy_model( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. The result type for the operation will be - :class:`~.endpoint_service.UndeployModelResponse`: + :class:`google.cloud.aiplatform_v1beta1.types.UndeployModelResponse` Response message for ``EndpointService.UndeployModel``. diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py index 5ea003b827..35f968b7c7 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py @@ -35,6 +35,7 @@ from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint from google.cloud.aiplatform_v1beta1.types import endpoint_service @@ -118,6 +119,22 @@ def _get_default_mtls_endpoint(api_endpoint): DEFAULT_ENDPOINT ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + EndpointServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -130,7 +147,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): kwargs: Additional arguments to pass to the constructor. Returns: - {@api.name}: The constructed client. + EndpointServiceClient: The constructed client. """ credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials @@ -254,10 +271,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.EndpointServiceTransport]): The + transport (Union[str, EndpointServiceTransport]): The transport to use. 
If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the + client_options (google.api_core.client_options.ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT @@ -293,21 +310,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -350,7 +363,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -368,17 +381,18 @@ def create_endpoint( r"""Creates an Endpoint. Args: - request (:class:`~.endpoint_service.CreateEndpointRequest`): + request (google.cloud.aiplatform_v1beta1.types.CreateEndpointRequest): The request object. Request message for ``EndpointService.CreateEndpoint``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location to create the Endpoint in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - endpoint (:class:`~.gca_endpoint.Endpoint`): + endpoint (google.cloud.aiplatform_v1beta1.types.Endpoint): Required. The Endpoint to create. This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this @@ -391,13 +405,11 @@ def create_endpoint( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.gca_endpoint.Endpoint`: Models are deployed - into it, and afterwards Endpoint is called to obtain - predictions and explanations. + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.Endpoint` Models are deployed into it, and afterwards Endpoint is called to obtain + predictions and explanations. """ # Create or coerce a protobuf request object. @@ -461,12 +473,13 @@ def get_endpoint( r"""Gets an Endpoint. Args: - request (:class:`~.endpoint_service.GetEndpointRequest`): + request (google.cloud.aiplatform_v1beta1.types.GetEndpointRequest): The request object. Request message for ``EndpointService.GetEndpoint`` - name (:class:`str`): + name (str): Required. The name of the Endpoint resource. 
Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -478,7 +491,7 @@ def get_endpoint( sent along with the request as metadata. Returns: - ~.endpoint.Endpoint: + google.cloud.aiplatform_v1beta1.types.Endpoint: Models are deployed into it, and afterwards Endpoint is called to obtain predictions and explanations. @@ -535,13 +548,14 @@ def list_endpoints( r"""Lists Endpoints in a Location. Args: - request (:class:`~.endpoint_service.ListEndpointsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListEndpointsRequest): The request object. Request message for ``EndpointService.ListEndpoints``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location from which to list the Endpoints. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -553,7 +567,7 @@ def list_endpoints( sent along with the request as metadata. Returns: - ~.pagers.ListEndpointsPager: + google.cloud.aiplatform_v1beta1.services.endpoint_service.pagers.ListEndpointsPager: Response message for ``EndpointService.ListEndpoints``. @@ -619,18 +633,20 @@ def update_endpoint( r"""Updates an Endpoint. Args: - request (:class:`~.endpoint_service.UpdateEndpointRequest`): + request (google.cloud.aiplatform_v1beta1.types.UpdateEndpointRequest): The request object. Request message for ``EndpointService.UpdateEndpoint``. - endpoint (:class:`~.gca_endpoint.Endpoint`): + endpoint (google.cloud.aiplatform_v1beta1.types.Endpoint): Required. The Endpoint which replaces the resource on the server. + This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - update_mask (:class:`~.field_mask.FieldMask`): - Required. The update mask applies to - the resource. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the resource. See + `FieldMask `__. + This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -642,7 +658,7 @@ def update_endpoint( sent along with the request as metadata. Returns: - ~.gca_endpoint.Endpoint: + google.cloud.aiplatform_v1beta1.types.Endpoint: Models are deployed into it, and afterwards Endpoint is called to obtain predictions and explanations. @@ -703,13 +719,14 @@ def delete_endpoint( r"""Deletes an Endpoint. Args: - request (:class:`~.endpoint_service.DeleteEndpointRequest`): + request (google.cloud.aiplatform_v1beta1.types.DeleteEndpointRequest): The request object. Request message for ``EndpointService.DeleteEndpoint``. - name (:class:`str`): + name (str): Required. The name of the Endpoint resource to be deleted. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -721,24 +738,22 @@ def delete_endpoint( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. 
A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -805,27 +820,29 @@ def deploy_model( DeployedModel within it. Args: - request (:class:`~.endpoint_service.DeployModelRequest`): + request (google.cloud.aiplatform_v1beta1.types.DeployModelRequest): The request object. Request message for ``EndpointService.DeployModel``. - endpoint (:class:`str`): + endpoint (str): Required. The name of the Endpoint resource into which to deploy a Model. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - deployed_model (:class:`~.gca_endpoint.DeployedModel`): + deployed_model (google.cloud.aiplatform_v1beta1.types.DeployedModel): Required. The DeployedModel to be created within the Endpoint. Note that ``Endpoint.traffic_split`` must be updated for the DeployedModel to start receiving traffic, either as part of this call, or via ``EndpointService.UpdateEndpoint``. + This corresponds to the ``deployed_model`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - traffic_split (:class:`Sequence[~.endpoint_service.DeployModelRequest.TrafficSplitEntry]`): + traffic_split (Sequence[google.cloud.aiplatform_v1beta1.types.DeployModelRequest.TrafficSplitEntry]): A map from a DeployedModel's ID to the percentage of this Endpoint's traffic that should be forwarded to that DeployedModel. @@ -841,6 +858,7 @@ def deploy_model( If this field is empty, then the Endpoint's ``traffic_split`` is not updated. + This corresponds to the ``traffic_split`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -852,11 +870,11 @@ def deploy_model( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. The result type for the operation will be - :class:`~.endpoint_service.DeployModelResponse`: + :class:`google.cloud.aiplatform_v1beta1.types.DeployModelResponse` Response message for ``EndpointService.DeployModel``. @@ -931,23 +949,25 @@ def undeploy_model( using. Args: - request (:class:`~.endpoint_service.UndeployModelRequest`): + request (google.cloud.aiplatform_v1beta1.types.UndeployModelRequest): The request object. Request message for ``EndpointService.UndeployModel``. - endpoint (:class:`str`): + endpoint (str): Required. The name of the Endpoint resource from which to undeploy a Model. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - deployed_model_id (:class:`str`): + deployed_model_id (str): Required. 
The ID of the DeployedModel to be undeployed from the Endpoint. + This corresponds to the ``deployed_model_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - traffic_split (:class:`Sequence[~.endpoint_service.UndeployModelRequest.TrafficSplitEntry]`): + traffic_split (Sequence[google.cloud.aiplatform_v1beta1.types.UndeployModelRequest.TrafficSplitEntry]): If this field is provided, then the Endpoint's ``traffic_split`` will be overwritten with it. If last DeployedModel is @@ -957,6 +977,7 @@ def undeploy_model( undeployed only if it doesn't have any traffic assigned to it when this method executes, or if this field unassigns any traffic to it. + This corresponds to the ``traffic_split`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -968,11 +989,11 @@ def undeploy_model( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. The result type for the operation will be - :class:`~.endpoint_service.UndeployModelResponse`: + :class:`google.cloud.aiplatform_v1beta1.types.UndeployModelResponse` Response message for ``EndpointService.UndeployModel``. diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py index 86320c2178..1ceb718df1 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/pagers.py @@ -25,7 +25,7 @@ class ListEndpointsPager: """A pager for iterating through ``list_endpoints`` requests. This class thinly wraps an initial - :class:`~.endpoint_service.ListEndpointsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListEndpointsResponse` object, and provides an ``__iter__`` method to iterate through its ``endpoints`` field. @@ -34,7 +34,7 @@ class ListEndpointsPager: through the ``endpoints`` field on the corresponding responses. - All the usual :class:`~.endpoint_service.ListEndpointsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListEndpointsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -52,9 +52,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.endpoint_service.ListEndpointsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListEndpointsRequest): The initial request object. - response (:class:`~.endpoint_service.ListEndpointsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListEndpointsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -87,7 +87,7 @@ class ListEndpointsAsyncPager: """A pager for iterating through ``list_endpoints`` requests. This class thinly wraps an initial - :class:`~.endpoint_service.ListEndpointsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListEndpointsResponse` object, and provides an ``__aiter__`` method to iterate through its ``endpoints`` field. @@ -96,7 +96,7 @@ class ListEndpointsAsyncPager: through the ``endpoints`` field on the corresponding responses. 
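Both pagers hide the page boundaries entirely: callers iterate elements, and further ``ListEndpointsRequest`` calls happen on demand. A minimal sketch (placeholder parent; ``client`` and ``async_client`` are assumed to be already-constructed EndpointService clients):

    parent = "projects/my-project/locations/us-central1"

    # Sync pager: plain iteration fetches additional pages as needed.
    for endpoint in client.list_endpoints(parent=parent):
        print(endpoint.display_name)

    # Async pager: the coroutine returns the pager, which supports ``async for``.
    async def show(async_client):
        pager = await async_client.list_endpoints(parent=parent)
        async for endpoint in pager:
            print(endpoint.display_name)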
- All the usual :class:`~.endpoint_service.ListEndpointsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListEndpointsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -114,9 +114,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.endpoint_service.ListEndpointsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListEndpointsRequest): The initial request object. - response (:class:`~.endpoint_service.ListEndpointsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListEndpointsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py index 70a87e920e..3d0695461d 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/__init__.py @@ -28,7 +28,6 @@ _transport_registry["grpc"] = EndpointServiceGrpcTransport _transport_registry["grpc_asyncio"] = EndpointServiceGrpcAsyncIOTransport - __all__ = ( "EndpointServiceTransport", "EndpointServiceGrpcTransport", diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py index 88b2b17c57..e55589de8f 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/base.py @@ -73,10 +73,10 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py index 5a2dee4d5a..e5b820de61 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc.py @@ -59,6 +59,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -89,6 +90,10 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. 
+ client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -105,6 +110,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -114,11 +124,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -148,6 +153,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -158,17 +167,28 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None # Run the base constructor. super().__init__( @@ -192,7 +212,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. + address (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -227,7 +247,8 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -238,13 +259,11 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( - self.grpc_channel - ) + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. 
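The new ``client_cert_source_for_mtls`` hook defers certificate loading to the transport: the callback is invoked once and its PEM bytes are wrapped in channel credentials, exactly as in the hunk above. A standalone sketch; the ``.pem`` paths are hypothetical, and in practice the callback typically comes from ``google.auth.transport.mtls.default_client_cert_source()``:

    import grpc

    def client_cert_source_for_mtls():
        # Return (certificate_chain, private_key), both PEM-encoded bytes.
        with open("client_cert.pem", "rb") as c, open("client_key.pem", "rb") as k:
            return c.read(), k.read()

    cert, key = client_cert_source_for_mtls()
    ssl_credentials = grpc.ssl_channel_credentials(
        certificate_chain=cert, private_key=key
    )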
- return self.__dict__["operations_client"] + return self._operations_client @property def create_endpoint( diff --git a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py index 661c63e9b9..a00971a72e 100644 --- a/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/endpoint_service/transports/grpc_asyncio.py @@ -103,6 +103,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -134,12 +135,16 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -150,6 +155,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -159,11 +169,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -193,6 +198,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -203,14 +212,24 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. 
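The repeated ``options`` pair added throughout these transports is worth spelling out: ``-1`` lifts gRPC's client-side message-size caps (the default receive cap is 4 MiB), which matters for large prediction and explanation payloads. The semantics can be seen on any channel; a minimal sketch against a hypothetical local server:

    import grpc

    channel = grpc.insecure_channel(
        "localhost:50051",
        options=[
            ("grpc.max_send_message_length", -1),    # no client-side send cap
            ("grpc.max_receive_message_length", -1),  # no client-side receive cap
        ],
    )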
self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) # Run the base constructor. @@ -224,6 +243,7 @@ def __init__( ) self._stubs = {} + self._operations_client = None @property def grpc_channel(self) -> aio.Channel: @@ -243,13 +263,13 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def create_endpoint( diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py index 2a24748d11..4c22267c01 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/async_client.py @@ -42,6 +42,7 @@ from google.cloud.aiplatform_v1beta1.types import ( data_labeling_job as gca_data_labeling_job, ) +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import ( @@ -92,6 +93,8 @@ class JobServiceAsyncClient: ) model_path = staticmethod(JobServiceClient.model_path) parse_model_path = staticmethod(JobServiceClient.parse_model_path) + trial_path = staticmethod(JobServiceClient.trial_path) + parse_trial_path = staticmethod(JobServiceClient.parse_trial_path) common_billing_account_path = staticmethod( JobServiceClient.common_billing_account_path @@ -116,6 +119,7 @@ class JobServiceAsyncClient: JobServiceClient.parse_common_location_path ) + from_service_account_info = JobServiceClient.from_service_account_info from_service_account_file = JobServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -194,17 +198,18 @@ async def create_custom_job( will be attempted to be run. Args: - request (:class:`~.job_service.CreateCustomJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateCustomJobRequest`): The request object. Request message for ``JobService.CreateCustomJob``. parent (:class:`str`): Required. The resource name of the Location to create the CustomJob in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - custom_job (:class:`~.gca_custom_job.CustomJob`): + custom_job (:class:`google.cloud.aiplatform_v1beta1.types.CustomJob`): Required. The CustomJob to create. This corresponds to the ``custom_job`` field on the ``request`` instance; if ``request`` is provided, this @@ -217,7 +222,7 @@ async def create_custom_job( sent along with the request as metadata. 
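The recurring note "if ``request`` is provided, this should not be set" describes the two mutually exclusive calling conventions every method supports: flattened keyword arguments or an explicit request object. A minimal sketch for ``get_custom_job`` (placeholder resource name; ``client`` is an assumed ``JobServiceClient``):

    from google.cloud.aiplatform_v1beta1.types import GetCustomJobRequest

    name = "projects/my-project/locations/us-central1/customJobs/123"

    # Convention 1: flattened keyword arguments.
    job = client.get_custom_job(name=name)

    # Convention 2: an explicit request object; ``name`` must not also be passed.
    job = client.get_custom_job(request=GetCustomJobRequest(name=name))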
Returns: - ~.gca_custom_job.CustomJob: + google.cloud.aiplatform_v1beta1.types.CustomJob: Represents a job that runs custom workloads such as a Docker container or a Python package. A CustomJob can have @@ -280,12 +285,13 @@ async def get_custom_job( r"""Gets a CustomJob. Args: - request (:class:`~.job_service.GetCustomJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.GetCustomJobRequest`): The request object. Request message for ``JobService.GetCustomJob``. name (:class:`str`): Required. The name of the CustomJob resource. Format: ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -297,7 +303,7 @@ async def get_custom_job( sent along with the request as metadata. Returns: - ~.custom_job.CustomJob: + google.cloud.aiplatform_v1beta1.types.CustomJob: Represents a job that runs custom workloads such as a Docker container or a Python package. A CustomJob can have @@ -358,13 +364,14 @@ async def list_custom_jobs( r"""Lists CustomJobs in a Location. Args: - request (:class:`~.job_service.ListCustomJobsRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListCustomJobsRequest`): The request object. Request message for ``JobService.ListCustomJobs``. parent (:class:`str`): Required. The resource name of the Location to list the CustomJobs from. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -376,7 +383,7 @@ async def list_custom_jobs( sent along with the request as metadata. Returns: - ~.pagers.ListCustomJobsAsyncPager: + google.cloud.aiplatform_v1beta1.services.job_service.pagers.ListCustomJobsAsyncPager: Response message for ``JobService.ListCustomJobs`` @@ -440,13 +447,14 @@ async def delete_custom_job( r"""Deletes a CustomJob. Args: - request (:class:`~.job_service.DeleteCustomJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteCustomJobRequest`): The request object. Request message for ``JobService.DeleteCustomJob``. name (:class:`str`): Required. The name of the CustomJob resource to be deleted. Format: ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -458,24 +466,22 @@ async def delete_custom_job( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. 
+ The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -548,12 +554,13 @@ async def cancel_custom_job( is set to ``CANCELLED``. Args: - request (:class:`~.job_service.CancelCustomJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CancelCustomJobRequest`): The request object. Request message for ``JobService.CancelCustomJob``. name (:class:`str`): Required. The name of the CustomJob to cancel. Format: ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -614,18 +621,20 @@ async def create_data_labeling_job( r"""Creates a DataLabelingJob. Args: - request (:class:`~.job_service.CreateDataLabelingJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateDataLabelingJobRequest`): The request object. Request message for [DataLabelingJobService.CreateDataLabelingJob][]. parent (:class:`str`): Required. The parent of the DataLabelingJob. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - data_labeling_job (:class:`~.gca_data_labeling_job.DataLabelingJob`): + data_labeling_job (:class:`google.cloud.aiplatform_v1beta1.types.DataLabelingJob`): Required. The DataLabelingJob to create. + This corresponds to the ``data_labeling_job`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -637,7 +646,7 @@ async def create_data_labeling_job( sent along with the request as metadata. Returns: - ~.gca_data_labeling_job.DataLabelingJob: + google.cloud.aiplatform_v1beta1.types.DataLabelingJob: DataLabelingJob is used to trigger a human labeling job on unlabeled data from the following Dataset: @@ -695,13 +704,14 @@ async def get_data_labeling_job( r"""Gets a DataLabelingJob. Args: - request (:class:`~.job_service.GetDataLabelingJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.GetDataLabelingJobRequest`): The request object. Request message for [DataLabelingJobService.GetDataLabelingJob][]. name (:class:`str`): Required. The name of the DataLabelingJob. Format: ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -713,7 +723,7 @@ async def get_data_labeling_job( sent along with the request as metadata. Returns: - ~.data_labeling_job.DataLabelingJob: + google.cloud.aiplatform_v1beta1.types.DataLabelingJob: DataLabelingJob is used to trigger a human labeling job on unlabeled data from the following Dataset: @@ -769,12 +779,13 @@ async def list_data_labeling_jobs( r"""Lists DataLabelingJobs in a Location. Args: - request (:class:`~.job_service.ListDataLabelingJobsRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsRequest`): The request object. Request message for [DataLabelingJobService.ListDataLabelingJobs][]. parent (:class:`str`): Required. The parent of the DataLabelingJob. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -786,7 +797,7 @@ async def list_data_labeling_jobs( sent along with the request as metadata. 
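Because the delete methods return a long-running operation whose result type is ``google.protobuf.Empty``, there is nothing to read from the result; awaiting it simply confirms completion and surfaces any error. A minimal sketch with a placeholder job name and an assumed ``JobServiceAsyncClient``:

    async def delete_job(client) -> None:
        operation = await client.delete_custom_job(
            name="projects/my-project/locations/us-central1/customJobs/123"
        )
        await operation.result()  # resolves to Empty; raises on failure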
Returns: - ~.pagers.ListDataLabelingJobsAsyncPager: + google.cloud.aiplatform_v1beta1.services.job_service.pagers.ListDataLabelingJobsAsyncPager: Response message for ``JobService.ListDataLabelingJobs``. @@ -850,7 +861,7 @@ async def delete_data_labeling_job( r"""Deletes a DataLabelingJob. Args: - request (:class:`~.job_service.DeleteDataLabelingJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteDataLabelingJobRequest`): The request object. Request message for ``JobService.DeleteDataLabelingJob``. name (:class:`str`): @@ -858,6 +869,7 @@ async def delete_data_labeling_job( Format: ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -869,24 +881,22 @@ async def delete_data_labeling_job( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -948,13 +958,14 @@ async def cancel_data_labeling_job( not guaranteed. Args: - request (:class:`~.job_service.CancelDataLabelingJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CancelDataLabelingJobRequest`): The request object. Request message for [DataLabelingJobService.CancelDataLabelingJob][]. name (:class:`str`): Required. The name of the DataLabelingJob. Format: ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1015,19 +1026,21 @@ async def create_hyperparameter_tuning_job( r"""Creates a HyperparameterTuningJob Args: - request (:class:`~.job_service.CreateHyperparameterTuningJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateHyperparameterTuningJobRequest`): The request object. Request message for ``JobService.CreateHyperparameterTuningJob``. parent (:class:`str`): Required. The resource name of the Location to create the HyperparameterTuningJob in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - hyperparameter_tuning_job (:class:`~.gca_hyperparameter_tuning_job.HyperparameterTuningJob`): + hyperparameter_tuning_job (:class:`google.cloud.aiplatform_v1beta1.types.HyperparameterTuningJob`): Required. The HyperparameterTuningJob to create. 
+ This corresponds to the ``hyperparameter_tuning_job`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1039,7 +1052,7 @@ async def create_hyperparameter_tuning_job( sent along with the request as metadata. Returns: - ~.gca_hyperparameter_tuning_job.HyperparameterTuningJob: + google.cloud.aiplatform_v1beta1.types.HyperparameterTuningJob: Represents a HyperparameterTuningJob. A HyperparameterTuningJob has a Study specification and multiple CustomJobs @@ -1098,7 +1111,7 @@ async def get_hyperparameter_tuning_job( r"""Gets a HyperparameterTuningJob Args: - request (:class:`~.job_service.GetHyperparameterTuningJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.GetHyperparameterTuningJobRequest`): The request object. Request message for ``JobService.GetHyperparameterTuningJob``. name (:class:`str`): @@ -1106,6 +1119,7 @@ async def get_hyperparameter_tuning_job( resource. Format: ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1117,7 +1131,7 @@ async def get_hyperparameter_tuning_job( sent along with the request as metadata. Returns: - ~.hyperparameter_tuning_job.HyperparameterTuningJob: + google.cloud.aiplatform_v1beta1.types.HyperparameterTuningJob: Represents a HyperparameterTuningJob. A HyperparameterTuningJob has a Study specification and multiple CustomJobs @@ -1174,13 +1188,14 @@ async def list_hyperparameter_tuning_jobs( r"""Lists HyperparameterTuningJobs in a Location. Args: - request (:class:`~.job_service.ListHyperparameterTuningJobsRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListHyperparameterTuningJobsRequest`): The request object. Request message for ``JobService.ListHyperparameterTuningJobs``. parent (:class:`str`): Required. The resource name of the Location to list the HyperparameterTuningJobs from. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1192,7 +1207,7 @@ async def list_hyperparameter_tuning_jobs( sent along with the request as metadata. Returns: - ~.pagers.ListHyperparameterTuningJobsAsyncPager: + google.cloud.aiplatform_v1beta1.services.job_service.pagers.ListHyperparameterTuningJobsAsyncPager: Response message for ``JobService.ListHyperparameterTuningJobs`` @@ -1256,7 +1271,7 @@ async def delete_hyperparameter_tuning_job( r"""Deletes a HyperparameterTuningJob. Args: - request (:class:`~.job_service.DeleteHyperparameterTuningJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteHyperparameterTuningJobRequest`): The request object. Request message for ``JobService.DeleteHyperparameterTuningJob``. name (:class:`str`): @@ -1264,6 +1279,7 @@ async def delete_hyperparameter_tuning_job( resource to be deleted. Format: ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1275,24 +1291,22 @@ async def delete_hyperparameter_tuning_job( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. 
- The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -1366,7 +1380,7 @@ async def cancel_hyperparameter_tuning_job( is set to ``CANCELLED``. Args: - request (:class:`~.job_service.CancelHyperparameterTuningJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CancelHyperparameterTuningJobRequest`): The request object. Request message for ``JobService.CancelHyperparameterTuningJob``. name (:class:`str`): @@ -1374,6 +1388,7 @@ async def cancel_hyperparameter_tuning_job( cancel. Format: ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1435,19 +1450,21 @@ async def create_batch_prediction_job( once created will right away be attempted to start. Args: - request (:class:`~.job_service.CreateBatchPredictionJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateBatchPredictionJobRequest`): The request object. Request message for ``JobService.CreateBatchPredictionJob``. parent (:class:`str`): Required. The resource name of the Location to create the BatchPredictionJob in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - batch_prediction_job (:class:`~.gca_batch_prediction_job.BatchPredictionJob`): + batch_prediction_job (:class:`google.cloud.aiplatform_v1beta1.types.BatchPredictionJob`): Required. The BatchPredictionJob to create. + This corresponds to the ``batch_prediction_job`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1459,14 +1476,13 @@ async def create_batch_prediction_job( sent along with the request as metadata. Returns: - ~.gca_batch_prediction_job.BatchPredictionJob: - A job that uses a - ``Model`` - to produce predictions on multiple [input - instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config]. - If predictions for significant portion of the instances - fail, the job may finish without attempting predictions - for all remaining instances. + google.cloud.aiplatform_v1beta1.types.BatchPredictionJob: + A job that uses a ``Model`` to produce predictions + on multiple [input + instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config]. + If predictions for significant portion of the + instances fail, the job may finish without attempting + predictions for all remaining instances. """ # Create or coerce a protobuf request object. 
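As the cancel docstrings note, cancellation is best effort and asynchronous: the server guarantees neither when nor whether the job stops, only that a successful cancellation ends with the job's state set to ``CANCELLED``. A polling sketch under those semantics (placeholder name; ``client`` is an assumed sync ``JobServiceClient``):

    import time

    from google.cloud.aiplatform_v1beta1.types import JobState

    name = (
        "projects/my-project/locations/us-central1"
        "/hyperparameterTuningJobs/123"
    )
    client.cancel_hyperparameter_tuning_job(name=name)

    # Poll until the job reaches a terminal state instead of assuming
    # it stopped the moment the cancel RPC returned.
    job = client.get_hyperparameter_tuning_job(name=name)
    while job.state not in (
        JobState.JOB_STATE_SUCCEEDED,
        JobState.JOB_STATE_FAILED,
        JobState.JOB_STATE_CANCELLED,
    ):
        time.sleep(10)
        job = client.get_hyperparameter_tuning_job(name=name)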
@@ -1521,7 +1537,7 @@ async def get_batch_prediction_job( r"""Gets a BatchPredictionJob Args: - request (:class:`~.job_service.GetBatchPredictionJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.GetBatchPredictionJobRequest`): The request object. Request message for ``JobService.GetBatchPredictionJob``. name (:class:`str`): @@ -1529,6 +1545,7 @@ async def get_batch_prediction_job( Format: ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1540,14 +1557,13 @@ async def get_batch_prediction_job( sent along with the request as metadata. Returns: - ~.batch_prediction_job.BatchPredictionJob: - A job that uses a - ``Model`` - to produce predictions on multiple [input - instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config]. - If predictions for significant portion of the instances - fail, the job may finish without attempting predictions - for all remaining instances. + google.cloud.aiplatform_v1beta1.types.BatchPredictionJob: + A job that uses a ``Model`` to produce predictions + on multiple [input + instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config]. + If predictions for significant portion of the + instances fail, the job may finish without attempting + predictions for all remaining instances. """ # Create or coerce a protobuf request object. @@ -1600,13 +1616,14 @@ async def list_batch_prediction_jobs( r"""Lists BatchPredictionJobs in a Location. Args: - request (:class:`~.job_service.ListBatchPredictionJobsRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListBatchPredictionJobsRequest`): The request object. Request message for ``JobService.ListBatchPredictionJobs``. parent (:class:`str`): Required. The resource name of the Location to list the BatchPredictionJobs from. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1618,7 +1635,7 @@ async def list_batch_prediction_jobs( sent along with the request as metadata. Returns: - ~.pagers.ListBatchPredictionJobsAsyncPager: + google.cloud.aiplatform_v1beta1.services.job_service.pagers.ListBatchPredictionJobsAsyncPager: Response message for ``JobService.ListBatchPredictionJobs`` @@ -1683,7 +1700,7 @@ async def delete_batch_prediction_job( jobs that already finished. Args: - request (:class:`~.job_service.DeleteBatchPredictionJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteBatchPredictionJobRequest`): The request object. Request message for ``JobService.DeleteBatchPredictionJob``. name (:class:`str`): @@ -1691,6 +1708,7 @@ async def delete_batch_prediction_job( be deleted. Format: ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1702,24 +1720,22 @@ async def delete_batch_prediction_job( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. 
A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -1791,7 +1807,7 @@ async def cancel_batch_prediction_job( are not deleted. Args: - request (:class:`~.job_service.CancelBatchPredictionJobRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CancelBatchPredictionJobRequest`): The request object. Request message for ``JobService.CancelBatchPredictionJob``. name (:class:`str`): @@ -1799,6 +1815,7 @@ async def cancel_batch_prediction_job( Format: ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/client.py b/google/cloud/aiplatform_v1beta1/services/job_service/client.py index a1eb7c38ce..54a53f26db 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/client.py @@ -46,6 +46,7 @@ from google.cloud.aiplatform_v1beta1.types import ( data_labeling_job as gca_data_labeling_job, ) +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job from google.cloud.aiplatform_v1beta1.types import ( @@ -136,6 +137,22 @@ def _get_default_mtls_endpoint(api_endpoint): DEFAULT_ENDPOINT ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + JobServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -148,7 +165,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): kwargs: Additional arguments to pass to the constructor. Returns: - {@api.name}: The constructed client. + JobServiceClient: The constructed client. 
""" credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials @@ -271,6 +288,22 @@ def parse_model_path(path: str) -> Dict[str, str]: ) return m.groupdict() if m else {} + @staticmethod + def trial_path(project: str, location: str, study: str, trial: str,) -> str: + """Return a fully-qualified trial string.""" + return "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) + + @staticmethod + def parse_trial_path(path: str) -> Dict[str, str]: + """Parse a trial path into its component segments.""" + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/studies/(?P.+?)/trials/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + @staticmethod def common_billing_account_path(billing_account: str,) -> str: """Return a fully-qualified billing_account string.""" @@ -346,10 +379,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.JobServiceTransport]): The + transport (Union[str, JobServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the + client_options (google.api_core.client_options.ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT @@ -385,21 +418,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -442,7 +471,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -461,17 +490,18 @@ def create_custom_job( will be attempted to be run. Args: - request (:class:`~.job_service.CreateCustomJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.CreateCustomJobRequest): The request object. Request message for ``JobService.CreateCustomJob``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location to create the CustomJob in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - custom_job (:class:`~.gca_custom_job.CustomJob`): + custom_job (google.cloud.aiplatform_v1beta1.types.CustomJob): Required. The CustomJob to create. 
This corresponds to the ``custom_job`` field on the ``request`` instance; if ``request`` is provided, this @@ -484,7 +514,7 @@ def create_custom_job( sent along with the request as metadata. Returns: - ~.gca_custom_job.CustomJob: + google.cloud.aiplatform_v1beta1.types.CustomJob: Represents a job that runs custom workloads such as a Docker container or a Python package. A CustomJob can have @@ -548,12 +578,13 @@ def get_custom_job( r"""Gets a CustomJob. Args: - request (:class:`~.job_service.GetCustomJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.GetCustomJobRequest): The request object. Request message for ``JobService.GetCustomJob``. - name (:class:`str`): + name (str): Required. The name of the CustomJob resource. Format: ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -565,7 +596,7 @@ def get_custom_job( sent along with the request as metadata. Returns: - ~.custom_job.CustomJob: + google.cloud.aiplatform_v1beta1.types.CustomJob: Represents a job that runs custom workloads such as a Docker container or a Python package. A CustomJob can have @@ -627,13 +658,14 @@ def list_custom_jobs( r"""Lists CustomJobs in a Location. Args: - request (:class:`~.job_service.ListCustomJobsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListCustomJobsRequest): The request object. Request message for ``JobService.ListCustomJobs``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location to list the CustomJobs from. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -645,7 +677,7 @@ def list_custom_jobs( sent along with the request as metadata. Returns: - ~.pagers.ListCustomJobsPager: + google.cloud.aiplatform_v1beta1.services.job_service.pagers.ListCustomJobsPager: Response message for ``JobService.ListCustomJobs`` @@ -710,13 +742,14 @@ def delete_custom_job( r"""Deletes a CustomJob. Args: - request (:class:`~.job_service.DeleteCustomJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.DeleteCustomJobRequest): The request object. Request message for ``JobService.DeleteCustomJob``. - name (:class:`str`): + name (str): Required. The name of the CustomJob resource to be deleted. Format: ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -728,24 +761,22 @@ def delete_custom_job( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. 
For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -819,12 +850,13 @@ def cancel_custom_job( is set to ``CANCELLED``. Args: - request (:class:`~.job_service.CancelCustomJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.CancelCustomJobRequest): The request object. Request message for ``JobService.CancelCustomJob``. - name (:class:`str`): + name (str): Required. The name of the CustomJob to cancel. Format: ``projects/{project}/locations/{location}/customJobs/{custom_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -886,18 +918,20 @@ def create_data_labeling_job( r"""Creates a DataLabelingJob. Args: - request (:class:`~.job_service.CreateDataLabelingJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.CreateDataLabelingJobRequest): The request object. Request message for [DataLabelingJobService.CreateDataLabelingJob][]. - parent (:class:`str`): + parent (str): Required. The parent of the DataLabelingJob. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - data_labeling_job (:class:`~.gca_data_labeling_job.DataLabelingJob`): + data_labeling_job (google.cloud.aiplatform_v1beta1.types.DataLabelingJob): Required. The DataLabelingJob to create. + This corresponds to the ``data_labeling_job`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -909,7 +943,7 @@ def create_data_labeling_job( sent along with the request as metadata. Returns: - ~.gca_data_labeling_job.DataLabelingJob: + google.cloud.aiplatform_v1beta1.types.DataLabelingJob: DataLabelingJob is used to trigger a human labeling job on unlabeled data from the following Dataset: @@ -968,13 +1002,14 @@ def get_data_labeling_job( r"""Gets a DataLabelingJob. Args: - request (:class:`~.job_service.GetDataLabelingJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.GetDataLabelingJobRequest): The request object. Request message for [DataLabelingJobService.GetDataLabelingJob][]. - name (:class:`str`): + name (str): Required. The name of the DataLabelingJob. Format: ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -986,7 +1021,7 @@ def get_data_labeling_job( sent along with the request as metadata. Returns: - ~.data_labeling_job.DataLabelingJob: + google.cloud.aiplatform_v1beta1.types.DataLabelingJob: DataLabelingJob is used to trigger a human labeling job on unlabeled data from the following Dataset: @@ -1043,12 +1078,13 @@ def list_data_labeling_jobs( r"""Lists DataLabelingJobs in a Location. Args: - request (:class:`~.job_service.ListDataLabelingJobsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsRequest): The request object. Request message for [DataLabelingJobService.ListDataLabelingJobs][]. - parent (:class:`str`): + parent (str): Required. The parent of the DataLabelingJob. 
Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1060,7 +1096,7 @@ def list_data_labeling_jobs( sent along with the request as metadata. Returns: - ~.pagers.ListDataLabelingJobsPager: + google.cloud.aiplatform_v1beta1.services.job_service.pagers.ListDataLabelingJobsPager: Response message for ``JobService.ListDataLabelingJobs``. @@ -1125,14 +1161,15 @@ def delete_data_labeling_job( r"""Deletes a DataLabelingJob. Args: - request (:class:`~.job_service.DeleteDataLabelingJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.DeleteDataLabelingJobRequest): The request object. Request message for ``JobService.DeleteDataLabelingJob``. - name (:class:`str`): + name (str): Required. The name of the DataLabelingJob to be deleted. Format: ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1144,24 +1181,22 @@ def delete_data_labeling_job( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -1224,13 +1259,14 @@ def cancel_data_labeling_job( not guaranteed. Args: - request (:class:`~.job_service.CancelDataLabelingJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.CancelDataLabelingJobRequest): The request object. Request message for [DataLabelingJobService.CancelDataLabelingJob][]. - name (:class:`str`): + name (str): Required. The name of the DataLabelingJob. Format: ``projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1292,19 +1328,21 @@ def create_hyperparameter_tuning_job( r"""Creates a HyperparameterTuningJob Args: - request (:class:`~.job_service.CreateHyperparameterTuningJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.CreateHyperparameterTuningJobRequest): The request object. Request message for ``JobService.CreateHyperparameterTuningJob``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location to create the HyperparameterTuningJob in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. 
- hyperparameter_tuning_job (:class:`~.gca_hyperparameter_tuning_job.HyperparameterTuningJob`): + hyperparameter_tuning_job (google.cloud.aiplatform_v1beta1.types.HyperparameterTuningJob): Required. The HyperparameterTuningJob to create. + This corresponds to the ``hyperparameter_tuning_job`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1316,7 +1354,7 @@ def create_hyperparameter_tuning_job( sent along with the request as metadata. Returns: - ~.gca_hyperparameter_tuning_job.HyperparameterTuningJob: + google.cloud.aiplatform_v1beta1.types.HyperparameterTuningJob: Represents a HyperparameterTuningJob. A HyperparameterTuningJob has a Study specification and multiple CustomJobs @@ -1378,14 +1416,15 @@ def get_hyperparameter_tuning_job( r"""Gets a HyperparameterTuningJob Args: - request (:class:`~.job_service.GetHyperparameterTuningJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.GetHyperparameterTuningJobRequest): The request object. Request message for ``JobService.GetHyperparameterTuningJob``. - name (:class:`str`): + name (str): Required. The name of the HyperparameterTuningJob resource. Format: ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1397,7 +1436,7 @@ def get_hyperparameter_tuning_job( sent along with the request as metadata. Returns: - ~.hyperparameter_tuning_job.HyperparameterTuningJob: + google.cloud.aiplatform_v1beta1.types.HyperparameterTuningJob: Represents a HyperparameterTuningJob. A HyperparameterTuningJob has a Study specification and multiple CustomJobs @@ -1457,13 +1496,14 @@ def list_hyperparameter_tuning_jobs( r"""Lists HyperparameterTuningJobs in a Location. Args: - request (:class:`~.job_service.ListHyperparameterTuningJobsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListHyperparameterTuningJobsRequest): The request object. Request message for ``JobService.ListHyperparameterTuningJobs``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location to list the HyperparameterTuningJobs from. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1475,7 +1515,7 @@ def list_hyperparameter_tuning_jobs( sent along with the request as metadata. Returns: - ~.pagers.ListHyperparameterTuningJobsPager: + google.cloud.aiplatform_v1beta1.services.job_service.pagers.ListHyperparameterTuningJobsPager: Response message for ``JobService.ListHyperparameterTuningJobs`` @@ -1542,14 +1582,15 @@ def delete_hyperparameter_tuning_job( r"""Deletes a HyperparameterTuningJob. Args: - request (:class:`~.job_service.DeleteHyperparameterTuningJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.DeleteHyperparameterTuningJobRequest): The request object. Request message for ``JobService.DeleteHyperparameterTuningJob``. - name (:class:`str`): + name (str): Required. The name of the HyperparameterTuningJob resource to be deleted. Format: ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1561,24 +1602,22 @@ def delete_hyperparameter_tuning_job( sent along with the request as metadata. 
Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -1655,14 +1694,15 @@ def cancel_hyperparameter_tuning_job( is set to ``CANCELLED``. Args: - request (:class:`~.job_service.CancelHyperparameterTuningJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.CancelHyperparameterTuningJobRequest): The request object. Request message for ``JobService.CancelHyperparameterTuningJob``. - name (:class:`str`): + name (str): Required. The name of the HyperparameterTuningJob to cancel. Format: ``projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1727,19 +1767,21 @@ def create_batch_prediction_job( once created will right away be attempted to start. Args: - request (:class:`~.job_service.CreateBatchPredictionJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.CreateBatchPredictionJobRequest): The request object. Request message for ``JobService.CreateBatchPredictionJob``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location to create the BatchPredictionJob in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - batch_prediction_job (:class:`~.gca_batch_prediction_job.BatchPredictionJob`): + batch_prediction_job (google.cloud.aiplatform_v1beta1.types.BatchPredictionJob): Required. The BatchPredictionJob to create. + This corresponds to the ``batch_prediction_job`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1751,14 +1793,13 @@ def create_batch_prediction_job( sent along with the request as metadata. Returns: - ~.gca_batch_prediction_job.BatchPredictionJob: - A job that uses a - ``Model`` - to produce predictions on multiple [input - instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config]. - If predictions for significant portion of the instances - fail, the job may finish without attempting predictions - for all remaining instances. + google.cloud.aiplatform_v1beta1.types.BatchPredictionJob: + A job that uses a ``Model`` to produce predictions + on multiple [input + instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config]. + If predictions for significant portion of the + instances fail, the job may finish without attempting + predictions for all remaining instances. 
""" # Create or coerce a protobuf request object. @@ -1816,14 +1857,15 @@ def get_batch_prediction_job( r"""Gets a BatchPredictionJob Args: - request (:class:`~.job_service.GetBatchPredictionJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.GetBatchPredictionJobRequest): The request object. Request message for ``JobService.GetBatchPredictionJob``. - name (:class:`str`): + name (str): Required. The name of the BatchPredictionJob resource. Format: ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1835,14 +1877,13 @@ def get_batch_prediction_job( sent along with the request as metadata. Returns: - ~.batch_prediction_job.BatchPredictionJob: - A job that uses a - ``Model`` - to produce predictions on multiple [input - instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config]. - If predictions for significant portion of the instances - fail, the job may finish without attempting predictions - for all remaining instances. + google.cloud.aiplatform_v1beta1.types.BatchPredictionJob: + A job that uses a ``Model`` to produce predictions + on multiple [input + instances][google.cloud.aiplatform.v1beta1.BatchPredictionJob.input_config]. + If predictions for significant portion of the + instances fail, the job may finish without attempting + predictions for all remaining instances. """ # Create or coerce a protobuf request object. @@ -1896,13 +1937,14 @@ def list_batch_prediction_jobs( r"""Lists BatchPredictionJobs in a Location. Args: - request (:class:`~.job_service.ListBatchPredictionJobsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListBatchPredictionJobsRequest): The request object. Request message for ``JobService.ListBatchPredictionJobs``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location to list the BatchPredictionJobs from. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1914,7 +1956,7 @@ def list_batch_prediction_jobs( sent along with the request as metadata. Returns: - ~.pagers.ListBatchPredictionJobsPager: + google.cloud.aiplatform_v1beta1.services.job_service.pagers.ListBatchPredictionJobsPager: Response message for ``JobService.ListBatchPredictionJobs`` @@ -1982,14 +2024,15 @@ def delete_batch_prediction_job( jobs that already finished. Args: - request (:class:`~.job_service.DeleteBatchPredictionJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.DeleteBatchPredictionJobRequest): The request object. Request message for ``JobService.DeleteBatchPredictionJob``. - name (:class:`str`): + name (str): Required. The name of the BatchPredictionJob resource to be deleted. Format: ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -2001,24 +2044,22 @@ def delete_batch_prediction_job( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. 
A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -2093,14 +2134,15 @@ def cancel_batch_prediction_job( are not deleted. Args: - request (:class:`~.job_service.CancelBatchPredictionJobRequest`): + request (google.cloud.aiplatform_v1beta1.types.CancelBatchPredictionJobRequest): The request object. Request message for ``JobService.CancelBatchPredictionJob``. - name (:class:`str`): + name (str): Required. The name of the BatchPredictionJob to cancel. Format: ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py index 05e5be73ca..845939923f 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/pagers.py @@ -28,7 +28,7 @@ class ListCustomJobsPager: """A pager for iterating through ``list_custom_jobs`` requests. This class thinly wraps an initial - :class:`~.job_service.ListCustomJobsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListCustomJobsResponse` object, and provides an ``__iter__`` method to iterate through its ``custom_jobs`` field. @@ -37,7 +37,7 @@ class ListCustomJobsPager: through the ``custom_jobs`` field on the corresponding responses. - All the usual :class:`~.job_service.ListCustomJobsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListCustomJobsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -55,9 +55,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.job_service.ListCustomJobsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListCustomJobsRequest): The initial request object. - response (:class:`~.job_service.ListCustomJobsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListCustomJobsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -90,7 +90,7 @@ class ListCustomJobsAsyncPager: """A pager for iterating through ``list_custom_jobs`` requests. This class thinly wraps an initial - :class:`~.job_service.ListCustomJobsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListCustomJobsResponse` object, and provides an ``__aiter__`` method to iterate through its ``custom_jobs`` field. @@ -99,7 +99,7 @@ class ListCustomJobsAsyncPager: through the ``custom_jobs`` field on the corresponding responses. 
- All the usual :class:`~.job_service.ListCustomJobsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListCustomJobsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -117,9 +117,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.job_service.ListCustomJobsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListCustomJobsRequest): The initial request object. - response (:class:`~.job_service.ListCustomJobsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListCustomJobsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -156,7 +156,7 @@ class ListDataLabelingJobsPager: """A pager for iterating through ``list_data_labeling_jobs`` requests. This class thinly wraps an initial - :class:`~.job_service.ListDataLabelingJobsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsResponse` object, and provides an ``__iter__`` method to iterate through its ``data_labeling_jobs`` field. @@ -165,7 +165,7 @@ class ListDataLabelingJobsPager: through the ``data_labeling_jobs`` field on the corresponding responses. - All the usual :class:`~.job_service.ListDataLabelingJobsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -183,9 +183,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.job_service.ListDataLabelingJobsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsRequest): The initial request object. - response (:class:`~.job_service.ListDataLabelingJobsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -218,7 +218,7 @@ class ListDataLabelingJobsAsyncPager: """A pager for iterating through ``list_data_labeling_jobs`` requests. This class thinly wraps an initial - :class:`~.job_service.ListDataLabelingJobsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsResponse` object, and provides an ``__aiter__`` method to iterate through its ``data_labeling_jobs`` field. @@ -227,7 +227,7 @@ class ListDataLabelingJobsAsyncPager: through the ``data_labeling_jobs`` field on the corresponding responses. - All the usual :class:`~.job_service.ListDataLabelingJobsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -245,9 +245,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.job_service.ListDataLabelingJobsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsRequest): The initial request object. 
- response (:class:`~.job_service.ListDataLabelingJobsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListDataLabelingJobsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -284,7 +284,7 @@ class ListHyperparameterTuningJobsPager: """A pager for iterating through ``list_hyperparameter_tuning_jobs`` requests. This class thinly wraps an initial - :class:`~.job_service.ListHyperparameterTuningJobsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListHyperparameterTuningJobsResponse` object, and provides an ``__iter__`` method to iterate through its ``hyperparameter_tuning_jobs`` field. @@ -293,7 +293,7 @@ class ListHyperparameterTuningJobsPager: through the ``hyperparameter_tuning_jobs`` field on the corresponding responses. - All the usual :class:`~.job_service.ListHyperparameterTuningJobsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListHyperparameterTuningJobsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -311,9 +311,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.job_service.ListHyperparameterTuningJobsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListHyperparameterTuningJobsRequest): The initial request object. - response (:class:`~.job_service.ListHyperparameterTuningJobsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListHyperparameterTuningJobsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -346,7 +346,7 @@ class ListHyperparameterTuningJobsAsyncPager: """A pager for iterating through ``list_hyperparameter_tuning_jobs`` requests. This class thinly wraps an initial - :class:`~.job_service.ListHyperparameterTuningJobsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListHyperparameterTuningJobsResponse` object, and provides an ``__aiter__`` method to iterate through its ``hyperparameter_tuning_jobs`` field. @@ -355,7 +355,7 @@ class ListHyperparameterTuningJobsAsyncPager: through the ``hyperparameter_tuning_jobs`` field on the corresponding responses. - All the usual :class:`~.job_service.ListHyperparameterTuningJobsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListHyperparameterTuningJobsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -375,9 +375,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.job_service.ListHyperparameterTuningJobsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListHyperparameterTuningJobsRequest): The initial request object. - response (:class:`~.job_service.ListHyperparameterTuningJobsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListHyperparameterTuningJobsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -418,7 +418,7 @@ class ListBatchPredictionJobsPager: """A pager for iterating through ``list_batch_prediction_jobs`` requests. 
This class thinly wraps an initial - :class:`~.job_service.ListBatchPredictionJobsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListBatchPredictionJobsResponse` object, and provides an ``__iter__`` method to iterate through its ``batch_prediction_jobs`` field. @@ -427,7 +427,7 @@ class ListBatchPredictionJobsPager: through the ``batch_prediction_jobs`` field on the corresponding responses. - All the usual :class:`~.job_service.ListBatchPredictionJobsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListBatchPredictionJobsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -445,9 +445,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.job_service.ListBatchPredictionJobsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListBatchPredictionJobsRequest): The initial request object. - response (:class:`~.job_service.ListBatchPredictionJobsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListBatchPredictionJobsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -480,7 +480,7 @@ class ListBatchPredictionJobsAsyncPager: """A pager for iterating through ``list_batch_prediction_jobs`` requests. This class thinly wraps an initial - :class:`~.job_service.ListBatchPredictionJobsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListBatchPredictionJobsResponse` object, and provides an ``__aiter__`` method to iterate through its ``batch_prediction_jobs`` field. @@ -489,7 +489,7 @@ class ListBatchPredictionJobsAsyncPager: through the ``batch_prediction_jobs`` field on the corresponding responses. - All the usual :class:`~.job_service.ListBatchPredictionJobsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListBatchPredictionJobsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -507,9 +507,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.job_service.ListBatchPredictionJobsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListBatchPredictionJobsRequest): The initial request object. - response (:class:`~.job_service.ListBatchPredictionJobsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListBatchPredictionJobsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
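The pager docstrings above now reference the fully-qualified request and response types; the pagers' behavior is unchanged, and they still fetch follow-up pages on demand. A minimal sketch of iterating one of them, with a placeholder parent:

    from google.cloud import aiplatform_v1beta1

    client = aiplatform_v1beta1.JobServiceClient()
    pager = client.list_batch_prediction_jobs(
        parent="projects/my-project/locations/us-central1",  # placeholder
    )
    for job in pager:  # the pager issues subsequent list requests as needed
        print(job.name, job.state)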
diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py index ca4d929cb5..349bfbcdea 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/__init__.py @@ -28,7 +28,6 @@ _transport_registry["grpc"] = JobServiceGrpcTransport _transport_registry["grpc_asyncio"] = JobServiceGrpcAsyncIOTransport - __all__ = ( "JobServiceTransport", "JobServiceGrpcTransport", diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py index abedda51f9..3d1f0be59b 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/base.py @@ -86,10 +86,10 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py index 859efdd7e7..c4efeaaf47 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc.py @@ -74,6 +74,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -104,6 +105,10 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -120,6 +125,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. 
@@ -129,11 +139,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -163,6 +168,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -173,17 +182,28 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None # Run the base constructor. super().__init__( @@ -207,7 +227,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. + address (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -242,7 +262,8 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -253,13 +274,11 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( - self.grpc_channel - ) + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def create_custom_job( diff --git a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py index 7d203c8d18..07655ba262 100644 --- a/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/job_service/transports/grpc_asyncio.py @@ -118,6 +118,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -149,12 +150,16 @@ def __init__( ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -165,6 +170,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -174,11 +184,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -208,6 +213,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -218,14 +227,24 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) # Run the base constructor. @@ -239,6 +258,7 @@ def __init__( ) self._stubs = {} + self._operations_client = None @property def grpc_channel(self) -> aio.Channel: @@ -258,13 +278,13 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. 
- return self.__dict__["operations_client"] + return self._operations_client @property def create_custom_job( diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py index af13c4d4fb..7577e15d1c 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/async_client.py @@ -96,6 +96,7 @@ class MigrationServiceAsyncClient: MigrationServiceClient.parse_common_location_path ) + from_service_account_info = MigrationServiceClient.from_service_account_info from_service_account_file = MigrationServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -175,7 +176,7 @@ async def search_migratable_resources( given location. Args: - request (:class:`~.migration_service.SearchMigratableResourcesRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.SearchMigratableResourcesRequest`): The request object. Request message for ``MigrationService.SearchMigratableResources``. parent (:class:`str`): @@ -184,6 +185,7 @@ async def search_migratable_resources( that the resources can be migrated to, not the resources' original location. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -195,7 +197,7 @@ async def search_migratable_resources( sent along with the request as metadata. Returns: - ~.pagers.SearchMigratableResourcesAsyncPager: + google.cloud.aiplatform_v1beta1.services.migration_service.pagers.SearchMigratableResourcesAsyncPager: Response message for ``MigrationService.SearchMigratableResources``. @@ -264,22 +266,24 @@ async def batch_migrate_resources( to AI Platform (Unified). Args: - request (:class:`~.migration_service.BatchMigrateResourcesRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.BatchMigrateResourcesRequest`): The request object. Request message for ``MigrationService.BatchMigrateResources``. parent (:class:`str`): Required. The location of the migrated resource will live in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - migrate_resource_requests (:class:`Sequence[~.migration_service.MigrateResourceRequest]`): + migrate_resource_requests (:class:`Sequence[google.cloud.aiplatform_v1beta1.types.MigrateResourceRequest]`): Required. The request messages specifying the resources to migrate. They must be in the same location as the destination. Up to 50 resources can be migrated in one batch. + This corresponds to the ``migrate_resource_requests`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -291,11 +295,11 @@ async def batch_migrate_resources( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. The result type for the operation will be - :class:`~.migration_service.BatchMigrateResourcesResponse`: + :class:`google.cloud.aiplatform_v1beta1.types.BatchMigrateResourcesResponse` Response message for ``MigrationService.BatchMigrateResources``. 
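The transport changes in this patch replace eagerly built ssl_channel_credentials with a client_cert_source_for_mtls callback that is invoked only when a channel is actually created. A minimal sketch of wiring such a callback through client_options, with placeholder certificate paths and an illustrative callback; per the __init__ logic above, it only takes effect when GOOGLE_API_USE_CLIENT_CERTIFICATE is "true":

    import os

    from google.api_core.client_options import ClientOptions
    from google.cloud import aiplatform_v1beta1

    os.environ["GOOGLE_API_USE_CLIENT_CERTIFICATE"] = "true"  # read via strtobool in __init__

    def client_cert_source():
        # Illustrative callback returning PEM-encoded cert and key bytes.
        with open("client.crt", "rb") as crt, open("client.key", "rb") as key:
            return crt.read(), key.read()

    client = aiplatform_v1beta1.JobServiceClient(
        client_options=ClientOptions(client_cert_source=client_cert_source),
    )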
diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py index bf1f8e5c6b..6d88a39046 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/client.py @@ -116,6 +116,22 @@ def _get_default_mtls_endpoint(api_endpoint): DEFAULT_ENDPOINT ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + MigrationServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -128,7 +144,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): kwargs: Additional arguments to pass to the constructor. Returns: - {@api.name}: The constructed client. + MigrationServiceClient: The constructed client. """ credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials @@ -331,10 +347,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.MigrationServiceTransport]): The + transport (Union[str, MigrationServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the + client_options (google.api_core.client_options.ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT @@ -370,21 +386,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -427,7 +439,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -447,15 +459,16 @@ def search_migratable_resources( given location. 
Args: - request (:class:`~.migration_service.SearchMigratableResourcesRequest`): + request (google.cloud.aiplatform_v1beta1.types.SearchMigratableResourcesRequest): The request object. Request message for ``MigrationService.SearchMigratableResources``. - parent (:class:`str`): + parent (str): Required. The location that the migratable resources should be searched from. It's the AI Platform location that the resources can be migrated to, not the resources' original location. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -467,7 +480,7 @@ def search_migratable_resources( sent along with the request as metadata. Returns: - ~.pagers.SearchMigratableResourcesPager: + google.cloud.aiplatform_v1beta1.services.migration_service.pagers.SearchMigratableResourcesPager: Response message for ``MigrationService.SearchMigratableResources``. @@ -539,22 +552,24 @@ def batch_migrate_resources( to AI Platform (Unified). Args: - request (:class:`~.migration_service.BatchMigrateResourcesRequest`): + request (google.cloud.aiplatform_v1beta1.types.BatchMigrateResourcesRequest): The request object. Request message for ``MigrationService.BatchMigrateResources``. - parent (:class:`str`): + parent (str): Required. The location of the migrated resource will live in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - migrate_resource_requests (:class:`Sequence[~.migration_service.MigrateResourceRequest]`): + migrate_resource_requests (Sequence[google.cloud.aiplatform_v1beta1.types.MigrateResourceRequest]): Required. The request messages specifying the resources to migrate. They must be in the same location as the destination. Up to 50 resources can be migrated in one batch. + This corresponds to the ``migrate_resource_requests`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -566,11 +581,11 @@ def batch_migrate_resources( sent along with the request as metadata. Returns: - ~.operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. The result type for the operation will be - :class:`~.migration_service.BatchMigrateResourcesResponse`: + :class:`google.cloud.aiplatform_v1beta1.types.BatchMigrateResourcesResponse` Response message for ``MigrationService.BatchMigrateResources``. diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py index 826325f2e4..d231e61235 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/pagers.py @@ -25,7 +25,7 @@ class SearchMigratableResourcesPager: """A pager for iterating through ``search_migratable_resources`` requests. This class thinly wraps an initial - :class:`~.migration_service.SearchMigratableResourcesResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.SearchMigratableResourcesResponse` object, and provides an ``__iter__`` method to iterate through its ``migratable_resources`` field. @@ -34,7 +34,7 @@ class SearchMigratableResourcesPager: through the ``migratable_resources`` field on the corresponding responses. 
- All the usual :class:`~.migration_service.SearchMigratableResourcesResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.SearchMigratableResourcesResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -52,9 +52,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.migration_service.SearchMigratableResourcesRequest`): + request (google.cloud.aiplatform_v1beta1.types.SearchMigratableResourcesRequest): The initial request object. - response (:class:`~.migration_service.SearchMigratableResourcesResponse`): + response (google.cloud.aiplatform_v1beta1.types.SearchMigratableResourcesResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -87,7 +87,7 @@ class SearchMigratableResourcesAsyncPager: """A pager for iterating through ``search_migratable_resources`` requests. This class thinly wraps an initial - :class:`~.migration_service.SearchMigratableResourcesResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.SearchMigratableResourcesResponse` object, and provides an ``__aiter__`` method to iterate through its ``migratable_resources`` field. @@ -96,7 +96,7 @@ class SearchMigratableResourcesAsyncPager: through the ``migratable_resources`` field on the corresponding responses. - All the usual :class:`~.migration_service.SearchMigratableResourcesResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.SearchMigratableResourcesResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -116,9 +116,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.migration_service.SearchMigratableResourcesRequest`): + request (google.cloud.aiplatform_v1beta1.types.SearchMigratableResourcesRequest): The initial request object. - response (:class:`~.migration_service.SearchMigratableResourcesResponse`): + response (google.cloud.aiplatform_v1beta1.types.SearchMigratableResourcesResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py index af727857e7..38c72756f6 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/__init__.py @@ -28,7 +28,6 @@ _transport_registry["grpc"] = MigrationServiceGrpcTransport _transport_registry["grpc_asyncio"] = MigrationServiceGrpcAsyncIOTransport - __all__ = ( "MigrationServiceTransport", "MigrationServiceGrpcTransport", diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py index e5feef70cf..cbcb288489 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/base.py @@ -71,10 +71,10 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. 
quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py index b73e3936d5..0c5e1a080e 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc.py @@ -61,6 +61,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -91,6 +92,10 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -107,6 +112,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -116,11 +126,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -150,6 +155,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -160,17 +169,28 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. 
self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None # Run the base constructor. super().__init__( @@ -194,7 +214,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. + address (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -229,7 +249,8 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -240,13 +261,11 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( - self.grpc_channel - ) + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def search_migratable_resources( diff --git a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py index 969d1a3b12..33e96e7170 100644 --- a/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/migration_service/transports/grpc_asyncio.py @@ -105,6 +105,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -136,12 +137,16 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. 
+ Generally, you only need to set this if you're developing your own client library. Raises: @@ -152,6 +157,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -161,11 +171,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -195,6 +200,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -205,14 +214,24 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) # Run the base constructor. @@ -226,6 +245,7 @@ def __init__( ) self._stubs = {} + self._operations_client = None @property def grpc_channel(self) -> aio.Channel: @@ -245,13 +265,13 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. 
- return self.__dict__["operations_client"] + return self._operations_client @property def search_migratable_resources( diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py index 3b27b6e184..c6b2b5f0fc 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/async_client.py @@ -32,6 +32,7 @@ from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.types import deployed_model_ref +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import model as gca_model @@ -101,6 +102,7 @@ class ModelServiceAsyncClient: ModelServiceClient.parse_common_location_path ) + from_service_account_info = ModelServiceClient.from_service_account_info from_service_account_file = ModelServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -178,17 +180,18 @@ async def upload_model( r"""Uploads a Model artifact into AI Platform. Args: - request (:class:`~.model_service.UploadModelRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.UploadModelRequest`): The request object. Request message for ``ModelService.UploadModel``. parent (:class:`str`): Required. The resource name of the Location into which to upload the Model. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - model (:class:`~.gca_model.Model`): + model (:class:`google.cloud.aiplatform_v1beta1.types.Model`): Required. The Model to create. This corresponds to the ``model`` field on the ``request`` instance; if ``request`` is provided, this @@ -201,12 +204,12 @@ async def upload_model( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. The result type for the operation will be - :class:`~.model_service.UploadModelResponse`: Response - message of + :class:`google.cloud.aiplatform_v1beta1.types.UploadModelResponse` + Response message of ``ModelService.UploadModel`` operation. @@ -271,12 +274,13 @@ async def get_model( r"""Gets a Model. Args: - request (:class:`~.model_service.GetModelRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.GetModelRequest`): The request object. Request message for ``ModelService.GetModel``. name (:class:`str`): Required. The name of the Model resource. Format: ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -288,7 +292,7 @@ async def get_model( sent along with the request as metadata. Returns: - ~.model.Model: + google.cloud.aiplatform_v1beta1.types.Model: A trained machine learning Model. """ # Create or coerce a protobuf request object. @@ -341,13 +345,14 @@ async def list_models( r"""Lists Models in a Location. Args: - request (:class:`~.model_service.ListModelsRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListModelsRequest`): The request object. 
Request message for ``ModelService.ListModels``. parent (:class:`str`): Required. The resource name of the Location to list the Models from. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -359,7 +364,7 @@ async def list_models( sent along with the request as metadata. Returns: - ~.pagers.ListModelsAsyncPager: + google.cloud.aiplatform_v1beta1.services.model_service.pagers.ListModelsAsyncPager: Response message for ``ModelService.ListModels`` @@ -424,20 +429,21 @@ async def update_model( r"""Updates a Model. Args: - request (:class:`~.model_service.UpdateModelRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateModelRequest`): The request object. Request message for ``ModelService.UpdateModel``. - model (:class:`~.gca_model.Model`): + model (:class:`google.cloud.aiplatform_v1beta1.types.Model`): Required. The Model which replaces the resource on the server. + This corresponds to the ``model`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - update_mask (:class:`~.field_mask.FieldMask`): + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): Required. The update mask applies to the resource. For the ``FieldMask`` definition, see + `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__. - [FieldMask](https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask). This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -449,7 +455,7 @@ async def update_model( sent along with the request as metadata. Returns: - ~.gca_model.Model: + google.cloud.aiplatform_v1beta1.types.Model: A trained machine learning Model. """ # Create or coerce a protobuf request object. @@ -508,13 +514,14 @@ async def delete_model( DeployedModels created from it. Args: - request (:class:`~.model_service.DeleteModelRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteModelRequest`): The request object. Request message for ``ModelService.DeleteModel``. name (:class:`str`): Required. The name of the Model resource to be deleted. Format: ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -526,24 +533,22 @@ async def delete_model( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}.
""" # Create or coerce a protobuf request object. @@ -608,19 +613,21 @@ async def export_model( format][google.cloud.aiplatform.v1beta1.Model.supported_export_formats]. Args: - request (:class:`~.model_service.ExportModelRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ExportModelRequest`): The request object. Request message for ``ModelService.ExportModel``. name (:class:`str`): Required. The resource name of the Model to export. Format: ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - output_config (:class:`~.model_service.ExportModelRequest.OutputConfig`): + output_config (:class:`google.cloud.aiplatform_v1beta1.types.ExportModelRequest.OutputConfig`): Required. The desired output location and configuration. + This corresponds to the ``output_config`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -632,12 +639,12 @@ async def export_model( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. The result type for the operation will be - :class:`~.model_service.ExportModelResponse`: Response - message of + :class:`google.cloud.aiplatform_v1beta1.types.ExportModelResponse` + Response message of ``ModelService.ExportModel`` operation. @@ -702,7 +709,7 @@ async def get_model_evaluation( r"""Gets a ModelEvaluation. Args: - request (:class:`~.model_service.GetModelEvaluationRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.GetModelEvaluationRequest`): The request object. Request message for ``ModelService.GetModelEvaluation``. name (:class:`str`): @@ -710,6 +717,7 @@ async def get_model_evaluation( Format: ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -721,7 +729,7 @@ async def get_model_evaluation( sent along with the request as metadata. Returns: - ~.model_evaluation.ModelEvaluation: + google.cloud.aiplatform_v1beta1.types.ModelEvaluation: A collection of metrics calculated by comparing Model's predictions on all of the test data against annotations from @@ -778,13 +786,14 @@ async def list_model_evaluations( r"""Lists ModelEvaluations in a Model. Args: - request (:class:`~.model_service.ListModelEvaluationsRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListModelEvaluationsRequest`): The request object. Request message for ``ModelService.ListModelEvaluations``. parent (:class:`str`): Required. The resource name of the Model to list the ModelEvaluations from. Format: ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -796,7 +805,7 @@ async def list_model_evaluations( sent along with the request as metadata. Returns: - ~.pagers.ListModelEvaluationsAsyncPager: + google.cloud.aiplatform_v1beta1.services.model_service.pagers.ListModelEvaluationsAsyncPager: Response message for ``ModelService.ListModelEvaluations``. @@ -860,7 +869,7 @@ async def get_model_evaluation_slice( r"""Gets a ModelEvaluationSlice. 
Args: - request (:class:`~.model_service.GetModelEvaluationSliceRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.GetModelEvaluationSliceRequest`): The request object. Request message for ``ModelService.GetModelEvaluationSlice``. name (:class:`str`): @@ -868,6 +877,7 @@ async def get_model_evaluation_slice( Format: ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -879,7 +889,7 @@ async def get_model_evaluation_slice( sent along with the request as metadata. Returns: - ~.model_evaluation_slice.ModelEvaluationSlice: + google.cloud.aiplatform_v1beta1.types.ModelEvaluationSlice: A collection of metrics calculated by comparing Model's predictions on a slice of the test data against ground truth @@ -936,7 +946,7 @@ async def list_model_evaluation_slices( r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: - request (:class:`~.model_service.ListModelEvaluationSlicesRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListModelEvaluationSlicesRequest`): The request object. Request message for ``ModelService.ListModelEvaluationSlices``. parent (:class:`str`): @@ -944,6 +954,7 @@ async def list_model_evaluation_slices( list the ModelEvaluationSlices from. Format: ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -955,7 +966,7 @@ async def list_model_evaluation_slices( sent along with the request as metadata. Returns: - ~.pagers.ListModelEvaluationSlicesAsyncPager: + google.cloud.aiplatform_v1beta1.services.model_service.pagers.ListModelEvaluationSlicesAsyncPager: Response message for ``ModelService.ListModelEvaluationSlices``. diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/client.py b/google/cloud/aiplatform_v1beta1/services/model_service/client.py index 30c00c0c9d..d357ad3b9a 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/client.py @@ -36,6 +36,7 @@ from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.types import deployed_model_ref +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import model as gca_model @@ -121,6 +122,22 @@ def _get_default_mtls_endpoint(api_endpoint): DEFAULT_ENDPOINT ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceClient: The constructed client. 
+ """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -133,7 +150,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): kwargs: Additional arguments to pass to the constructor. Returns: - {@api.name}: The constructed client. + ModelServiceClient: The constructed client. """ credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials @@ -315,10 +332,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.ModelServiceTransport]): The + transport (Union[str, ModelServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the + client_options (google.api_core.client_options.ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT @@ -354,21 +371,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -411,7 +424,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -429,17 +442,18 @@ def upload_model( r"""Uploads a Model artifact into AI Platform. Args: - request (:class:`~.model_service.UploadModelRequest`): + request (google.cloud.aiplatform_v1beta1.types.UploadModelRequest): The request object. Request message for ``ModelService.UploadModel``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location into which to upload the Model. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - model (:class:`~.gca_model.Model`): + model (google.cloud.aiplatform_v1beta1.types.Model): Required. The Model to create. This corresponds to the ``model`` field on the ``request`` instance; if ``request`` is provided, this @@ -452,12 +466,12 @@ def upload_model( sent along with the request as metadata. 
Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. The result type for the operation will be - :class:`~.model_service.UploadModelResponse`: Response - message of + :class:`google.cloud.aiplatform_v1beta1.types.UploadModelResponse` + Response message of ``ModelService.UploadModel`` operation. @@ -523,12 +537,13 @@ def get_model( r"""Gets a Model. Args: - request (:class:`~.model_service.GetModelRequest`): + request (google.cloud.aiplatform_v1beta1.types.GetModelRequest): The request object. Request message for ``ModelService.GetModel``. - name (:class:`str`): + name (str): Required. The name of the Model resource. Format: ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -540,7 +555,7 @@ def get_model( sent along with the request as metadata. Returns: - ~.model.Model: + google.cloud.aiplatform_v1beta1.types.Model: A trained machine learning Model. """ # Create or coerce a protobuf request object. @@ -594,13 +609,14 @@ def list_models( r"""Lists Models in a Location. Args: - request (:class:`~.model_service.ListModelsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListModelsRequest): The request object. Request message for ``ModelService.ListModels``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location to list the Models from. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -612,7 +628,7 @@ def list_models( sent along with the request as metadata. Returns: - ~.pagers.ListModelsPager: + google.cloud.aiplatform_v1beta1.services.model_service.pagers.ListModelsPager: Response message for ``ModelService.ListModels`` @@ -678,20 +694,21 @@ def update_model( r"""Updates a Model. Args: - request (:class:`~.model_service.UpdateModelRequest`): + request (google.cloud.aiplatform_v1beta1.types.UpdateModelRequest): The request object. Request message for ``ModelService.UpdateModel``. - model (:class:`~.gca_model.Model`): + model (google.cloud.aiplatform_v1beta1.types.Model): Required. The Model which replaces the resource on the server. + This corresponds to the ``model`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - update_mask (:class:`~.field_mask.FieldMask`): + update_mask (google.protobuf.field_mask_pb2.FieldMask): Required. The update mask applies to the resource. For the ``FieldMask`` definition, see + `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__. - [FieldMask](https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask). This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -703,7 +720,7 @@ def update_model( sent along with the request as metadata. Returns: - ~.gca_model.Model: + google.cloud.aiplatform_v1beta1.types.Model: A trained machine learning Model. """ # Create or coerce a protobuf request object. @@ -763,13 +780,14 @@ def delete_model( DeployedModels created from it. Args: - request (:class:`~.model_service.DeleteModelRequest`): + request (google.cloud.aiplatform_v1beta1.types.DeleteModelRequest): The request object. Request message for ``ModelService.DeleteModel``. - name (:class:`str`): + name (str): Required. The name of the Model resource to be deleted.
Format: ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -781,24 +799,22 @@ def delete_model( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -864,19 +880,21 @@ def export_model( format][google.cloud.aiplatform.v1beta1.Model.supported_export_formats]. Args: - request (:class:`~.model_service.ExportModelRequest`): + request (google.cloud.aiplatform_v1beta1.types.ExportModelRequest): The request object. Request message for ``ModelService.ExportModel``. - name (:class:`str`): + name (str): Required. The resource name of the Model to export. Format: ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - output_config (:class:`~.model_service.ExportModelRequest.OutputConfig`): + output_config (google.cloud.aiplatform_v1beta1.types.ExportModelRequest.OutputConfig): Required. The desired output location and configuration. + This corresponds to the ``output_config`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -888,12 +906,12 @@ def export_model( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. The result type for the operation will be - :class:`~.model_service.ExportModelResponse`: Response - message of + :class:`google.cloud.aiplatform_v1beta1.types.ExportModelResponse` + Response message of ``ModelService.ExportModel`` operation. @@ -959,14 +977,15 @@ def get_model_evaluation( r"""Gets a ModelEvaluation. Args: - request (:class:`~.model_service.GetModelEvaluationRequest`): + request (google.cloud.aiplatform_v1beta1.types.GetModelEvaluationRequest): The request object. Request message for ``ModelService.GetModelEvaluation``. - name (:class:`str`): + name (str): Required. The name of the ModelEvaluation resource. Format: ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -978,7 +997,7 @@ def get_model_evaluation( sent along with the request as metadata. 
Returns: - ~.model_evaluation.ModelEvaluation: + google.cloud.aiplatform_v1beta1.types.ModelEvaluation: A collection of metrics calculated by comparing Model's predictions on all of the test data against annotations from @@ -1036,13 +1055,14 @@ def list_model_evaluations( r"""Lists ModelEvaluations in a Model. Args: - request (:class:`~.model_service.ListModelEvaluationsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationsRequest): The request object. Request message for ``ModelService.ListModelEvaluations``. - parent (:class:`str`): + parent (str): Required. The resource name of the Model to list the ModelEvaluations from. Format: ``projects/{project}/locations/{location}/models/{model}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1054,7 +1074,7 @@ def list_model_evaluations( sent along with the request as metadata. Returns: - ~.pagers.ListModelEvaluationsPager: + google.cloud.aiplatform_v1beta1.services.model_service.pagers.ListModelEvaluationsPager: Response message for ``ModelService.ListModelEvaluations``. @@ -1119,14 +1139,15 @@ def get_model_evaluation_slice( r"""Gets a ModelEvaluationSlice. Args: - request (:class:`~.model_service.GetModelEvaluationSliceRequest`): + request (google.cloud.aiplatform_v1beta1.types.GetModelEvaluationSliceRequest): The request object. Request message for ``ModelService.GetModelEvaluationSlice``. - name (:class:`str`): + name (str): Required. The name of the ModelEvaluationSlice resource. Format: ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1138,7 +1159,7 @@ def get_model_evaluation_slice( sent along with the request as metadata. Returns: - ~.model_evaluation_slice.ModelEvaluationSlice: + google.cloud.aiplatform_v1beta1.types.ModelEvaluationSlice: A collection of metrics calculated by comparing Model's predictions on a slice of the test data against ground truth @@ -1198,14 +1219,15 @@ def list_model_evaluation_slices( r"""Lists ModelEvaluationSlices in a ModelEvaluation. Args: - request (:class:`~.model_service.ListModelEvaluationSlicesRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationSlicesRequest): The request object. Request message for ``ModelService.ListModelEvaluationSlices``. - parent (:class:`str`): + parent (str): Required. The resource name of the ModelEvaluation to list the ModelEvaluationSlices from. Format: ``projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -1217,7 +1239,7 @@ def list_model_evaluation_slices( sent along with the request as metadata. Returns: - ~.pagers.ListModelEvaluationSlicesPager: + google.cloud.aiplatform_v1beta1.services.model_service.pagers.ListModelEvaluationSlicesPager: Response message for ``ModelService.ListModelEvaluationSlices``. 
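The synchronous ``ModelServiceClient`` changes mirror the async ones: a new ``from_service_account_info`` constructor, fully qualified docstring types, and long-running-operation return types spelled out as ``google.api_core.operation.Operation``. A short sketch of the update/delete flow those docstrings describe, assuming placeholder resource names::

    from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient
    from google.protobuf import field_mask_pb2

    client = ModelServiceClient()  # credentials resolved from the environment

    # update_model takes the replacement Model plus a protobuf FieldMask
    # naming the fields to overwrite.
    model_name = "projects/my-project/locations/us-central1/models/123"
    model = client.get_model(name=model_name)
    model.display_name = "renamed-model"
    updated = client.update_model(
        model=model,
        update_mask=field_mask_pb2.FieldMask(paths=["display_name"]),
    )
    print(updated.display_name)

    # delete_model returns a google.api_core.operation.Operation whose
    # result type is google.protobuf.empty_pb2.Empty; result() blocks
    # until the operation completes.
    client.delete_model(name=model_name).result()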
diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py index 1ab3aacb91..046f462b45 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/pagers.py @@ -27,7 +27,7 @@ class ListModelsPager: """A pager for iterating through ``list_models`` requests. This class thinly wraps an initial - :class:`~.model_service.ListModelsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListModelsResponse` object, and provides an ``__iter__`` method to iterate through its ``models`` field. @@ -36,7 +36,7 @@ class ListModelsPager: through the ``models`` field on the corresponding responses. - All the usual :class:`~.model_service.ListModelsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListModelsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -54,9 +54,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.model_service.ListModelsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListModelsRequest): The initial request object. - response (:class:`~.model_service.ListModelsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListModelsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -89,7 +89,7 @@ class ListModelsAsyncPager: """A pager for iterating through ``list_models`` requests. This class thinly wraps an initial - :class:`~.model_service.ListModelsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListModelsResponse` object, and provides an ``__aiter__`` method to iterate through its ``models`` field. @@ -98,7 +98,7 @@ class ListModelsAsyncPager: through the ``models`` field on the corresponding responses. - All the usual :class:`~.model_service.ListModelsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListModelsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -116,9 +116,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.model_service.ListModelsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListModelsRequest): The initial request object. - response (:class:`~.model_service.ListModelsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListModelsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -155,7 +155,7 @@ class ListModelEvaluationsPager: """A pager for iterating through ``list_model_evaluations`` requests. This class thinly wraps an initial - :class:`~.model_service.ListModelEvaluationsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListModelEvaluationsResponse` object, and provides an ``__iter__`` method to iterate through its ``model_evaluations`` field. @@ -164,7 +164,7 @@ class ListModelEvaluationsPager: through the ``model_evaluations`` field on the corresponding responses. 
- All the usual :class:`~.model_service.ListModelEvaluationsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListModelEvaluationsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -182,9 +182,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.model_service.ListModelEvaluationsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationsRequest): The initial request object. - response (:class:`~.model_service.ListModelEvaluationsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -217,7 +217,7 @@ class ListModelEvaluationsAsyncPager: """A pager for iterating through ``list_model_evaluations`` requests. This class thinly wraps an initial - :class:`~.model_service.ListModelEvaluationsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListModelEvaluationsResponse` object, and provides an ``__aiter__`` method to iterate through its ``model_evaluations`` field. @@ -226,7 +226,7 @@ class ListModelEvaluationsAsyncPager: through the ``model_evaluations`` field on the corresponding responses. - All the usual :class:`~.model_service.ListModelEvaluationsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListModelEvaluationsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -244,9 +244,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.model_service.ListModelEvaluationsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationsRequest): The initial request object. - response (:class:`~.model_service.ListModelEvaluationsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -283,7 +283,7 @@ class ListModelEvaluationSlicesPager: """A pager for iterating through ``list_model_evaluation_slices`` requests. This class thinly wraps an initial - :class:`~.model_service.ListModelEvaluationSlicesResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListModelEvaluationSlicesResponse` object, and provides an ``__iter__`` method to iterate through its ``model_evaluation_slices`` field. @@ -292,7 +292,7 @@ class ListModelEvaluationSlicesPager: through the ``model_evaluation_slices`` field on the corresponding responses. - All the usual :class:`~.model_service.ListModelEvaluationSlicesResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListModelEvaluationSlicesResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -310,9 +310,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.model_service.ListModelEvaluationSlicesRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationSlicesRequest): The initial request object. 
- response (:class:`~.model_service.ListModelEvaluationSlicesResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationSlicesResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -345,7 +345,7 @@ class ListModelEvaluationSlicesAsyncPager: """A pager for iterating through ``list_model_evaluation_slices`` requests. This class thinly wraps an initial - :class:`~.model_service.ListModelEvaluationSlicesResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListModelEvaluationSlicesResponse` object, and provides an ``__aiter__`` method to iterate through its ``model_evaluation_slices`` field. @@ -354,7 +354,7 @@ class ListModelEvaluationSlicesAsyncPager: through the ``model_evaluation_slices`` field on the corresponding responses. - All the usual :class:`~.model_service.ListModelEvaluationSlicesResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListModelEvaluationSlicesResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -374,9 +374,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.model_service.ListModelEvaluationSlicesRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationSlicesRequest): The initial request object. - response (:class:`~.model_service.ListModelEvaluationSlicesResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListModelEvaluationSlicesResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py index a521df9229..5d1cb51abc 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/__init__.py @@ -28,7 +28,6 @@ _transport_registry["grpc"] = ModelServiceGrpcTransport _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport - __all__ = ( "ModelServiceTransport", "ModelServiceGrpcTransport", diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py index a0b896cdf4..2f87fc98dd 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/base.py @@ -75,10 +75,10 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. 
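The pager docstrings above all follow the same pattern: the pager thinly wraps the initial list response, and iterating it fetches subsequent pages on demand while exposing the most recent response's attributes. A sketch with a placeholder parent::

    from google.cloud.aiplatform_v1beta1.services.model_service import ModelServiceClient

    client = ModelServiceClient()
    pager = client.list_models(parent="projects/my-project/locations/us-central1")

    # Item-wise iteration; further ListModelsResponse pages are fetched
    # transparently as the loop advances.
    for model in pager:
        print(model.name)

    # pager.pages yields whole ListModelsResponse objects instead, and the
    # usual response attributes (e.g. next_page_token) are also available
    # directly on the pager, reflecting the most recent response.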
diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py index 98f90e9dc8..39452c0cd6 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc.py @@ -63,6 +63,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -93,6 +94,10 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -109,6 +114,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -118,11 +128,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -152,6 +157,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -162,17 +171,28 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None # Run the base constructor. super().__init__( @@ -196,7 +216,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. + address (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. 
These credentials identify this application to the service. If @@ -231,7 +251,8 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -242,13 +263,11 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( - self.grpc_channel - ) + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def upload_model( diff --git a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py index bce1fed9a6..d05bebeeec 100644 --- a/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/model_service/transports/grpc_asyncio.py @@ -107,6 +107,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -138,12 +139,16 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -154,6 +159,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. 
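The new client_cert_source_for_mtls parameter threaded through both transports above moves grpc.ssl_channel_credentials construction into the transport itself. A hedged sketch of how a caller might supply it — the helper name and certificate paths are illustrative, and application default credentials are assumed:

from google.cloud.aiplatform_v1beta1.services.model_service.transports import (
    ModelServiceGrpcTransport,
)

def my_cert_source():  # hypothetical callback returning PEM cert/key bytes
    with open("client_cert.pem", "rb") as f:
        cert = f.read()
    with open("client_key.pem", "rb") as f:
        key = f.read()
    return cert, key

# Per the docstring added above, the callback is ignored if a channel or
# ssl_channel_credentials is provided; otherwise the transport builds the
# mTLS channel credentials from it.
transport = ModelServiceGrpcTransport(
    host="aiplatform.mtls.googleapis.com",
    client_cert_source_for_mtls=my_cert_source,
)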
@@ -163,11 +173,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -197,6 +202,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -207,14 +216,24 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) # Run the base constructor. @@ -228,6 +247,7 @@ def __init__( ) self._stubs = {} + self._operations_client = None @property def grpc_channel(self) -> aio.Channel: @@ -247,13 +267,13 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def upload_model( diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py index ef420aae0b..170b5b8f59 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/async_client.py @@ -31,6 +31,7 @@ from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import pipeline_service @@ -95,6 +96,7 @@ class PipelineServiceAsyncClient: PipelineServiceClient.parse_common_location_path ) + from_service_account_info = PipelineServiceClient.from_service_account_info from_service_account_file = PipelineServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -173,19 +175,21 @@ async def create_training_pipeline( TrainingPipeline right away will be attempted to be run. Args: - request (:class:`~.pipeline_service.CreateTrainingPipelineRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateTrainingPipelineRequest`): The request object. 
Request message for ``PipelineService.CreateTrainingPipeline``. parent (:class:`str`): Required. The resource name of the Location to create the TrainingPipeline in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - training_pipeline (:class:`~.gca_training_pipeline.TrainingPipeline`): + training_pipeline (:class:`google.cloud.aiplatform_v1beta1.types.TrainingPipeline`): Required. The TrainingPipeline to create. + This corresponds to the ``training_pipeline`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -197,13 +201,13 @@ async def create_training_pipeline( sent along with the request as metadata. Returns: - ~.gca_training_pipeline.TrainingPipeline: - The TrainingPipeline orchestrates tasks associated with - training a Model. It always executes the training task, - and optionally may also export data from AI Platform's - Dataset which becomes the training input, - ``upload`` - the Model to AI Platform, and evaluate the Model. + google.cloud.aiplatform_v1beta1.types.TrainingPipeline: + The TrainingPipeline orchestrates tasks associated with training a Model. It + always executes the training task, and optionally may + also export data from AI Platform's Dataset which + becomes the training input, + ``upload`` + the Model to AI Platform, and evaluate the Model. """ # Create or coerce a protobuf request object. @@ -258,7 +262,7 @@ async def get_training_pipeline( r"""Gets a TrainingPipeline. Args: - request (:class:`~.pipeline_service.GetTrainingPipelineRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.GetTrainingPipelineRequest`): The request object. Request message for ``PipelineService.GetTrainingPipeline``. name (:class:`str`): @@ -266,6 +270,7 @@ async def get_training_pipeline( Format: ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -277,13 +282,13 @@ async def get_training_pipeline( sent along with the request as metadata. Returns: - ~.training_pipeline.TrainingPipeline: - The TrainingPipeline orchestrates tasks associated with - training a Model. It always executes the training task, - and optionally may also export data from AI Platform's - Dataset which becomes the training input, - ``upload`` - the Model to AI Platform, and evaluate the Model. + google.cloud.aiplatform_v1beta1.types.TrainingPipeline: + The TrainingPipeline orchestrates tasks associated with training a Model. It + always executes the training task, and optionally may + also export data from AI Platform's Dataset which + becomes the training input, + ``upload`` + the Model to AI Platform, and evaluate the Model. """ # Create or coerce a protobuf request object. @@ -336,13 +341,14 @@ async def list_training_pipelines( r"""Lists TrainingPipelines in a Location. Args: - request (:class:`~.pipeline_service.ListTrainingPipelinesRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListTrainingPipelinesRequest`): The request object. Request message for ``PipelineService.ListTrainingPipelines``. parent (:class:`str`): Required. The resource name of the Location to list the TrainingPipelines from. 
Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -354,7 +360,7 @@ async def list_training_pipelines( sent along with the request as metadata. Returns: - ~.pagers.ListTrainingPipelinesAsyncPager: + google.cloud.aiplatform_v1beta1.services.pipeline_service.pagers.ListTrainingPipelinesAsyncPager: Response message for ``PipelineService.ListTrainingPipelines`` @@ -418,7 +424,7 @@ async def delete_training_pipeline( r"""Deletes a TrainingPipeline. Args: - request (:class:`~.pipeline_service.DeleteTrainingPipelineRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteTrainingPipelineRequest`): The request object. Request message for ``PipelineService.DeleteTrainingPipeline``. name (:class:`str`): @@ -426,6 +432,7 @@ async def delete_training_pipeline( be deleted. Format: ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -437,24 +444,22 @@ async def delete_training_pipeline( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -527,7 +532,7 @@ async def cancel_training_pipeline( is set to ``CANCELLED``. Args: - request (:class:`~.pipeline_service.CancelTrainingPipelineRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CancelTrainingPipelineRequest`): The request object. Request message for ``PipelineService.CancelTrainingPipeline``. name (:class:`str`): @@ -535,6 +540,7 @@ async def cancel_training_pipeline( Format: ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. 
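Since the reflowed docstring above is hard to read in diff form: delete_training_pipeline returns a long-running operation whose result type is google.protobuf.empty_pb2.Empty. An illustrative async call, with a placeholder resource name and credentials assumed from the environment:

import asyncio
from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
    PipelineServiceAsyncClient,
)

async def delete_pipeline():
    client = PipelineServiceAsyncClient()
    operation = await client.delete_training_pipeline(
        name=(
            "projects/my-project/locations/us-central1/"
            "trainingPipelines/123"  # hypothetical resource name
        )
    )
    # The operation resolves to Empty (JSON: {}) once the delete completes.
    await operation.result()

asyncio.run(delete_pipeline())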
diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py index e3e7d6aeda..25aa10df28 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/client.py @@ -35,6 +35,7 @@ from google.api_core import operation as ga_operation # type: ignore from google.api_core import operation_async # type: ignore from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.cloud.aiplatform_v1beta1.types import pipeline_service @@ -123,6 +124,22 @@ def _get_default_mtls_endpoint(api_endpoint): DEFAULT_ENDPOINT ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PipelineServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -135,7 +152,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): kwargs: Additional arguments to pass to the constructor. Returns: - {@api.name}: The constructed client. + PipelineServiceClient: The constructed client. """ credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials @@ -277,10 +294,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.PipelineServiceTransport]): The + transport (Union[str, PipelineServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the + client_options (google.api_core.client_options.ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT @@ -316,21 +333,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. 
if client_options.api_endpoint is not None: @@ -373,7 +386,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -392,19 +405,21 @@ def create_training_pipeline( TrainingPipeline right away will be attempted to be run. Args: - request (:class:`~.pipeline_service.CreateTrainingPipelineRequest`): + request (google.cloud.aiplatform_v1beta1.types.CreateTrainingPipelineRequest): The request object. Request message for ``PipelineService.CreateTrainingPipeline``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location to create the TrainingPipeline in. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - training_pipeline (:class:`~.gca_training_pipeline.TrainingPipeline`): + training_pipeline (google.cloud.aiplatform_v1beta1.types.TrainingPipeline): Required. The TrainingPipeline to create. + This corresponds to the ``training_pipeline`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -416,13 +431,13 @@ def create_training_pipeline( sent along with the request as metadata. Returns: - ~.gca_training_pipeline.TrainingPipeline: - The TrainingPipeline orchestrates tasks associated with - training a Model. It always executes the training task, - and optionally may also export data from AI Platform's - Dataset which becomes the training input, - ``upload`` - the Model to AI Platform, and evaluate the Model. + google.cloud.aiplatform_v1beta1.types.TrainingPipeline: + The TrainingPipeline orchestrates tasks associated with training a Model. It + always executes the training task, and optionally may + also export data from AI Platform's Dataset which + becomes the training input, + ``upload`` + the Model to AI Platform, and evaluate the Model. """ # Create or coerce a protobuf request object. @@ -478,14 +493,15 @@ def get_training_pipeline( r"""Gets a TrainingPipeline. Args: - request (:class:`~.pipeline_service.GetTrainingPipelineRequest`): + request (google.cloud.aiplatform_v1beta1.types.GetTrainingPipelineRequest): The request object. Request message for ``PipelineService.GetTrainingPipeline``. - name (:class:`str`): + name (str): Required. The name of the TrainingPipeline resource. Format: ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -497,13 +513,13 @@ def get_training_pipeline( sent along with the request as metadata. Returns: - ~.training_pipeline.TrainingPipeline: - The TrainingPipeline orchestrates tasks associated with - training a Model. It always executes the training task, - and optionally may also export data from AI Platform's - Dataset which becomes the training input, - ``upload`` - the Model to AI Platform, and evaluate the Model. + google.cloud.aiplatform_v1beta1.types.TrainingPipeline: + The TrainingPipeline orchestrates tasks associated with training a Model. It + always executes the training task, and optionally may + also export data from AI Platform's Dataset which + becomes the training input, + ``upload`` + the Model to AI Platform, and evaluate the Model. 
""" # Create or coerce a protobuf request object. @@ -557,13 +573,14 @@ def list_training_pipelines( r"""Lists TrainingPipelines in a Location. Args: - request (:class:`~.pipeline_service.ListTrainingPipelinesRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListTrainingPipelinesRequest): The request object. Request message for ``PipelineService.ListTrainingPipelines``. - parent (:class:`str`): + parent (str): Required. The resource name of the Location to list the TrainingPipelines from. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -575,7 +592,7 @@ def list_training_pipelines( sent along with the request as metadata. Returns: - ~.pagers.ListTrainingPipelinesPager: + google.cloud.aiplatform_v1beta1.services.pipeline_service.pagers.ListTrainingPipelinesPager: Response message for ``PipelineService.ListTrainingPipelines`` @@ -640,14 +657,15 @@ def delete_training_pipeline( r"""Deletes a TrainingPipeline. Args: - request (:class:`~.pipeline_service.DeleteTrainingPipelineRequest`): + request (google.cloud.aiplatform_v1beta1.types.DeleteTrainingPipelineRequest): The request object. Request message for ``PipelineService.DeleteTrainingPipeline``. - name (:class:`str`): + name (str): Required. The name of the TrainingPipeline resource to be deleted. Format: ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -659,24 +677,22 @@ def delete_training_pipeline( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -750,14 +766,15 @@ def cancel_training_pipeline( is set to ``CANCELLED``. Args: - request (:class:`~.pipeline_service.CancelTrainingPipelineRequest`): + request (google.cloud.aiplatform_v1beta1.types.CancelTrainingPipelineRequest): The request object. Request message for ``PipelineService.CancelTrainingPipeline``. - name (:class:`str`): + name (str): Required. The name of the TrainingPipeline to cancel. Format: ``projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. 
diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py index 98e5a51a17..1c8616e0a1 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/pagers.py @@ -25,7 +25,7 @@ class ListTrainingPipelinesPager: """A pager for iterating through ``list_training_pipelines`` requests. This class thinly wraps an initial - :class:`~.pipeline_service.ListTrainingPipelinesResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListTrainingPipelinesResponse` object, and provides an ``__iter__`` method to iterate through its ``training_pipelines`` field. @@ -34,7 +34,7 @@ class ListTrainingPipelinesPager: through the ``training_pipelines`` field on the corresponding responses. - All the usual :class:`~.pipeline_service.ListTrainingPipelinesResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListTrainingPipelinesResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -52,9 +52,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.pipeline_service.ListTrainingPipelinesRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListTrainingPipelinesRequest): The initial request object. - response (:class:`~.pipeline_service.ListTrainingPipelinesResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListTrainingPipelinesResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -87,7 +87,7 @@ class ListTrainingPipelinesAsyncPager: """A pager for iterating through ``list_training_pipelines`` requests. This class thinly wraps an initial - :class:`~.pipeline_service.ListTrainingPipelinesResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListTrainingPipelinesResponse` object, and provides an ``__aiter__`` method to iterate through its ``training_pipelines`` field. @@ -96,7 +96,7 @@ class ListTrainingPipelinesAsyncPager: through the ``training_pipelines`` field on the corresponding responses. - All the usual :class:`~.pipeline_service.ListTrainingPipelinesResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListTrainingPipelinesResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -116,9 +116,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.pipeline_service.ListTrainingPipelinesRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListTrainingPipelinesRequest): The initial request object. - response (:class:`~.pipeline_service.ListTrainingPipelinesResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListTrainingPipelinesResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. 
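As the updated pager docstrings note, iteration is the usual entry point and attribute lookups fall through to the most recent response. A usage sketch with a placeholder parent, assuming default credentials:

from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
    PipelineServiceClient,
)

client = PipelineServiceClient()
pager = client.list_training_pipelines(
    parent="projects/my-project/locations/us-central1"  # placeholder
)
for pipeline in pager:  # each item is a TrainingPipeline message
    print(pipeline.display_name)
# pager.next_page_token reflects the most recently retained response.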
diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py index d9d71a892b..9d4610087a 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/__init__.py @@ -28,7 +28,6 @@ _transport_registry["grpc"] = PipelineServiceGrpcTransport _transport_registry["grpc_asyncio"] = PipelineServiceGrpcAsyncIOTransport - __all__ = ( "PipelineServiceTransport", "PipelineServiceGrpcTransport", diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py index 25e8acb412..41123b8615 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/base.py @@ -76,10 +76,10 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py index fc6ca0087e..6aa6880fdb 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc.py @@ -64,6 +64,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -94,6 +95,10 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -110,6 +115,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. 
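A behavioral note on the warning relocation in the hunks above and below: moving the warnings.warn calls ahead of the channel branch means the deprecation now fires even when a pre-built channel is passed. A hedged illustration (the localhost channel is a stand-in, and the base transport still resolves application default credentials):

import warnings
import grpc
from google.cloud.aiplatform_v1beta1.services.pipeline_service.transports import (
    PipelineServiceGrpcTransport,
)

channel = grpc.insecure_channel("localhost:50051")  # placeholder channel
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    PipelineServiceGrpcTransport(
        channel=channel,
        api_mtls_endpoint="aiplatform.mtls.googleapis.com",  # deprecated arg
    )
assert any(issubclass(w.category, DeprecationWarning) for w in caught)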
@@ -119,11 +129,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -153,6 +158,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -163,17 +172,28 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None # Run the base constructor. super().__init__( @@ -197,7 +217,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. + address (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -232,7 +252,8 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -243,13 +264,11 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( - self.grpc_channel - ) + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def create_training_pipeline( diff --git a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py index 173c771eab..76f21faf50 100644 --- a/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/pipeline_service/transports/grpc_asyncio.py @@ -108,6 +108,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -139,12 +140,16 @@ def __init__( ``api_mtls_endpoint`` is None. 
ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -155,6 +160,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -164,11 +174,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -198,6 +203,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -208,14 +217,24 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) # Run the base constructor. @@ -229,6 +248,7 @@ def __init__( ) self._stubs = {} + self._operations_client = None @property def grpc_channel(self) -> aio.Channel: @@ -248,13 +268,13 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. 
- return self.__dict__["operations_client"] + return self._operations_client @property def create_training_pipeline( diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py index bb58b0bfac..8eeae282f6 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/async_client.py @@ -77,6 +77,7 @@ class PredictionServiceAsyncClient: PredictionServiceClient.parse_common_location_path ) + from_service_account_info = PredictionServiceClient.from_service_account_info from_service_account_file = PredictionServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -155,17 +156,18 @@ async def predict( r"""Perform an online prediction. Args: - request (:class:`~.prediction_service.PredictRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.PredictRequest`): The request object. Request message for ``PredictionService.Predict``. endpoint (:class:`str`): Required. The name of the Endpoint requested to serve the prediction. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - instances (:class:`Sequence[~.struct.Value]`): + instances (:class:`Sequence[google.protobuf.struct_pb2.Value]`): Required. The instances that are the input to the prediction call. A DeployedModel may have an upper limit on the number of instances it supports per request, and @@ -177,16 +179,18 @@ async def predict( [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``instance_schema_uri``. + This corresponds to the ``instances`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - parameters (:class:`~.struct.Value`): + parameters (:class:`google.protobuf.struct_pb2.Value`): The parameters that govern the prediction. The schema of the parameters may be specified via Endpoint's DeployedModels' [Model's ][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``parameters_schema_uri``. + This corresponds to the ``parameters`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -198,7 +202,7 @@ async def predict( sent along with the request as metadata. Returns: - ~.prediction_service.PredictResponse: + google.cloud.aiplatform_v1beta1.types.PredictResponse: Response message for ``PredictionService.Predict``. @@ -272,17 +276,18 @@ async def explain( explanation_spec. Args: - request (:class:`~.prediction_service.ExplainRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ExplainRequest`): The request object. Request message for ``PredictionService.Explain``. endpoint (:class:`str`): Required. The name of the Endpoint requested to serve the explanation. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - instances (:class:`Sequence[~.struct.Value]`): + instances (:class:`Sequence[google.protobuf.struct_pb2.Value]`): Required. The instances that are the input to the explanation call. 
A DeployedModel may have an upper limit on the number of instances it supports per @@ -294,16 +299,18 @@ async def explain( [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``instance_schema_uri``. + This corresponds to the ``instances`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - parameters (:class:`~.struct.Value`): + parameters (:class:`google.protobuf.struct_pb2.Value`): The parameters that govern the prediction. The schema of the parameters may be specified via Endpoint's DeployedModels' [Model's ][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``parameters_schema_uri``. + This corresponds to the ``parameters`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -311,6 +318,7 @@ async def explain( If specified, this ExplainRequest will be served by the chosen DeployedModel, overriding ``Endpoint.traffic_split``. + This corresponds to the ``deployed_model_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -322,7 +330,7 @@ async def explain( sent along with the request as metadata. Returns: - ~.prediction_service.ExplainResponse: + google.cloud.aiplatform_v1beta1.types.ExplainResponse: Response message for ``PredictionService.Explain``. diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py index 9a5976d697..89b90b5c12 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/client.py @@ -113,6 +113,22 @@ def _get_default_mtls_endpoint(api_endpoint): DEFAULT_ENDPOINT ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PredictionServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -125,7 +141,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): kwargs: Additional arguments to pass to the constructor. Returns: - {@api.name}: The constructed client. + PredictionServiceClient: The constructed client. """ credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials @@ -233,10 +249,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.PredictionServiceTransport]): The + transport (Union[str, PredictionServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the + client_options (google.api_core.client_options.ClientOptions): Custom options for the client. 
It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT @@ -272,21 +288,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -329,7 +341,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -348,17 +360,18 @@ def predict( r"""Perform an online prediction. Args: - request (:class:`~.prediction_service.PredictRequest`): + request (google.cloud.aiplatform_v1beta1.types.PredictRequest): The request object. Request message for ``PredictionService.Predict``. - endpoint (:class:`str`): + endpoint (str): Required. The name of the Endpoint requested to serve the prediction. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - instances (:class:`Sequence[~.struct.Value]`): + instances (Sequence[google.protobuf.struct_pb2.Value]): Required. The instances that are the input to the prediction call. A DeployedModel may have an upper limit on the number of instances it supports per request, and @@ -370,16 +383,18 @@ def predict( [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``instance_schema_uri``. + This corresponds to the ``instances`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - parameters (:class:`~.struct.Value`): + parameters (google.protobuf.struct_pb2.Value): The parameters that govern the prediction. The schema of the parameters may be specified via Endpoint's DeployedModels' [Model's ][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``parameters_schema_uri``. + This corresponds to the ``parameters`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -391,7 +406,7 @@ def predict( sent along with the request as metadata. Returns: - ~.prediction_service.PredictResponse: + google.cloud.aiplatform_v1beta1.types.PredictResponse: Response message for ``PredictionService.Predict``. @@ -466,17 +481,18 @@ def explain( explanation_spec. Args: - request (:class:`~.prediction_service.ExplainRequest`): + request (google.cloud.aiplatform_v1beta1.types.ExplainRequest): The request object. Request message for ``PredictionService.Explain``. 
- endpoint (:class:`str`): + endpoint (str): Required. The name of the Endpoint requested to serve the explanation. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` + This corresponds to the ``endpoint`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - instances (:class:`Sequence[~.struct.Value]`): + instances (Sequence[google.protobuf.struct_pb2.Value]): Required. The instances that are the input to the explanation call. A DeployedModel may have an upper limit on the number of instances it supports per @@ -488,23 +504,26 @@ def explain( [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``instance_schema_uri``. + This corresponds to the ``instances`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - parameters (:class:`~.struct.Value`): + parameters (google.protobuf.struct_pb2.Value): The parameters that govern the prediction. The schema of the parameters may be specified via Endpoint's DeployedModels' [Model's ][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``parameters_schema_uri``. + This corresponds to the ``parameters`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - deployed_model_id (:class:`str`): + deployed_model_id (str): If specified, this ExplainRequest will be served by the chosen DeployedModel, overriding ``Endpoint.traffic_split``. + This corresponds to the ``deployed_model_id`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -516,7 +535,7 @@ def explain( sent along with the request as metadata. Returns: - ~.prediction_service.ExplainResponse: + google.cloud.aiplatform_v1beta1.types.ExplainResponse: Response message for ``PredictionService.Explain``. diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py index 7eb32ea86d..9ec1369a05 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/__init__.py @@ -28,7 +28,6 @@ _transport_registry["grpc"] = PredictionServiceGrpcTransport _transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport - __all__ = ( "PredictionServiceTransport", "PredictionServiceGrpcTransport", diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py index f2f7a028cc..0c82f7d83c 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/base.py @@ -69,10 +69,10 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. 
If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py index fbfbabef1b..53345710c6 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc.py @@ -57,6 +57,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -87,6 +88,10 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -103,6 +108,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -112,11 +122,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -146,6 +151,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -156,14 +165,24 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._stubs = {} # type: Dict[str, Callable] @@ -190,7 +209,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. 
+ address (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. These credentials identify this application to the service. If @@ -225,7 +244,8 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property diff --git a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py index 237fa8a75c..e1493acc9c 100644 --- a/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/prediction_service/transports/grpc_asyncio.py @@ -101,6 +101,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -132,12 +133,16 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -148,6 +153,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. 
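For background on the options list added throughout these transports: grpc-python caps messages at 4 MiB by default, and -1 lifts the cap, which matters for large prediction payloads. This is standard gRPC channel behavior, shown here outside any client:

import grpc

channel = grpc.insecure_channel(
    "localhost:50051",  # placeholder target
    options=[
        ("grpc.max_send_message_length", -1),     # -1 removes the 4 MiB cap
        ("grpc.max_receive_message_length", -1),
    ],
)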
@@ -157,11 +167,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -191,6 +196,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -201,14 +210,24 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) # Run the base constructor. diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py index c693126d4c..c85d91436e 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/async_client.py @@ -95,6 +95,7 @@ class SpecialistPoolServiceAsyncClient: SpecialistPoolServiceClient.parse_common_location_path ) + from_service_account_info = SpecialistPoolServiceClient.from_service_account_info from_service_account_file = SpecialistPoolServiceClient.from_service_account_file from_service_account_json = from_service_account_file @@ -173,19 +174,21 @@ async def create_specialist_pool( r"""Creates a SpecialistPool. Args: - request (:class:`~.specialist_pool_service.CreateSpecialistPoolRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.CreateSpecialistPoolRequest`): The request object. Request message for ``SpecialistPoolService.CreateSpecialistPool``. parent (:class:`str`): Required. The parent Project name for the new SpecialistPool. The form is ``projects/{project}/locations/{location}``. + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - specialist_pool (:class:`~.gca_specialist_pool.SpecialistPool`): + specialist_pool (:class:`google.cloud.aiplatform_v1beta1.types.SpecialistPool`): Required. The SpecialistPool to create. + This corresponds to the ``specialist_pool`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -197,19 +200,17 @@ async def create_specialist_pool( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.gca_specialist_pool.SpecialistPool`: - SpecialistPool represents customers' own workforce to - work on their data labeling jobs. 
It includes a group of - specialist managers who are responsible for managing the - labelers in this pool as well as customers' data - labeling jobs associated with this pool. Customers - create specialist pool as well as start data labeling - jobs on Cloud, managers and labelers work with the jobs - using CrowdCompute console. + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.SpecialistPool` SpecialistPool represents customers' own workforce to work on their data + labeling jobs. It includes a group of specialist + managers who are responsible for managing the + labelers in this pool as well as customers' data + labeling jobs associated with this pool. Customers + create specialist pool as well as start data labeling + jobs on Cloud, managers and labelers work with the + jobs using CrowdCompute console. """ # Create or coerce a protobuf request object. @@ -272,7 +273,7 @@ async def get_specialist_pool( r"""Gets a SpecialistPool. Args: - request (:class:`~.specialist_pool_service.GetSpecialistPoolRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.GetSpecialistPoolRequest`): The request object. Request message for ``SpecialistPoolService.GetSpecialistPool``. name (:class:`str`): @@ -280,6 +281,7 @@ async def get_specialist_pool( form is ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}``. + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -291,7 +293,7 @@ async def get_specialist_pool( sent along with the request as metadata. Returns: - ~.specialist_pool.SpecialistPool: + google.cloud.aiplatform_v1beta1.types.SpecialistPool: SpecialistPool represents customers' own workforce to work on their data labeling jobs. It includes a group of @@ -355,13 +357,14 @@ async def list_specialist_pools( r"""Lists SpecialistPools in a Location. Args: - request (:class:`~.specialist_pool_service.ListSpecialistPoolsRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.ListSpecialistPoolsRequest`): The request object. Request message for ``SpecialistPoolService.ListSpecialistPools``. parent (:class:`str`): Required. The name of the SpecialistPool's parent resource. Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -373,7 +376,7 @@ async def list_specialist_pools( sent along with the request as metadata. Returns: - ~.pagers.ListSpecialistPoolsAsyncPager: + google.cloud.aiplatform_v1beta1.services.specialist_pool_service.pagers.ListSpecialistPoolsAsyncPager: Response message for ``SpecialistPoolService.ListSpecialistPools``. @@ -438,13 +441,14 @@ async def delete_specialist_pool( in the pool. Args: - request (:class:`~.specialist_pool_service.DeleteSpecialistPoolRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.DeleteSpecialistPoolRequest`): The request object. Request message for ``SpecialistPoolService.DeleteSpecialistPool``. name (:class:`str`): Required. The resource name of the SpecialistPool to delete. Format: ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -456,24 +460,22 @@ async def delete_specialist_pool( sent along with the request as metadata. 
Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -535,18 +537,20 @@ async def update_specialist_pool( r"""Updates a SpecialistPool. Args: - request (:class:`~.specialist_pool_service.UpdateSpecialistPoolRequest`): + request (:class:`google.cloud.aiplatform_v1beta1.types.UpdateSpecialistPoolRequest`): The request object. Request message for ``SpecialistPoolService.UpdateSpecialistPool``. - specialist_pool (:class:`~.gca_specialist_pool.SpecialistPool`): + specialist_pool (:class:`google.cloud.aiplatform_v1beta1.types.SpecialistPool`): Required. The SpecialistPool which replaces the resource on the server. + This corresponds to the ``specialist_pool`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - update_mask (:class:`~.field_mask.FieldMask`): + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): Required. The update mask applies to the resource. + This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -558,19 +562,17 @@ async def update_specialist_pool( sent along with the request as metadata. Returns: - ~.operation_async.AsyncOperation: + google.api_core.operation_async.AsyncOperation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.gca_specialist_pool.SpecialistPool`: - SpecialistPool represents customers' own workforce to - work on their data labeling jobs. It includes a group of - specialist managers who are responsible for managing the - labelers in this pool as well as customers' data - labeling jobs associated with this pool. Customers - create specialist pool as well as start data labeling - jobs on Cloud, managers and labelers work with the jobs - using CrowdCompute console. + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.SpecialistPool` SpecialistPool represents customers' own workforce to work on their data + labeling jobs. It includes a group of specialist + managers who are responsible for managing the + labelers in this pool as well as customers' data + labeling jobs associated with this pool. Customers + create specialist pool as well as start data labeling + jobs on Cloud, managers and labelers work with the + jobs using CrowdCompute console. """ # Create or coerce a protobuf request object. 
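A note on the transport hunks above: the deprecated ``api_mtls_endpoint`` and ``client_cert_source`` arguments now only emit a ``DeprecationWarning``, and the new ``client_cert_source_for_mtls`` callback becomes the supported way to supply a client certificate. A minimal sketch of that pattern, assuming placeholder PEM files on disk (the paths are illustrative, not part of the patch):

import grpc

# Hypothetical callback: returns (certificate_chain, private_key) as PEM bytes.
# The file names are placeholders for wherever the client certificate lives.
def client_cert_source_for_mtls():
    with open("client-cert.pem", "rb") as cert_file, open("client-key.pem", "rb") as key_file:
        return cert_file.read(), key_file.read()

# The transport resolves the callback once and builds channel-level SSL
# credentials from it, mirroring the hunks above.
cert, key = client_cert_source_for_mtls()
ssl_credentials = grpc.ssl_channel_credentials(
    certificate_chain=cert, private_key=key
)

Passing the callback instead of ready-made ``ssl_channel_credentials`` keeps the gRPC-specific credential object inside the transport, which is why the client layer below now hands the transport ``client_cert_source_for_mtls`` rather than building the credentials itself.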
diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py index efc19eca12..6018955006 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/client.py @@ -125,6 +125,22 @@ def _get_default_mtls_endpoint(api_endpoint): DEFAULT_ENDPOINT ) + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + SpecialistPoolServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + @classmethod def from_service_account_file(cls, filename: str, *args, **kwargs): """Creates an instance of this client using the provided credentials @@ -137,7 +153,7 @@ def from_service_account_file(cls, filename: str, *args, **kwargs): kwargs: Additional arguments to pass to the constructor. Returns: - {@api.name}: The constructed client. + SpecialistPoolServiceClient: The constructed client. """ credentials = service_account.Credentials.from_service_account_file(filename) kwargs["credentials"] = credentials @@ -245,10 +261,10 @@ def __init__( credentials identify the application to the service; if none are specified, the client will attempt to ascertain the credentials from the environment. - transport (Union[str, ~.SpecialistPoolServiceTransport]): The + transport (Union[str, SpecialistPoolServiceTransport]): The transport to use. If set to None, a transport is chosen automatically. - client_options (client_options_lib.ClientOptions): Custom options for the + client_options (google.api_core.client_options.ClientOptions): Custom options for the client. It won't take effect if a ``transport`` instance is provided. (1) The ``api_endpoint`` property can be used to override the default endpoint provided by the client. GOOGLE_API_USE_MTLS_ENDPOINT @@ -284,21 +300,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -341,7 +353,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -359,19 +371,21 @@ def create_specialist_pool( r"""Creates a SpecialistPool. 
Args: - request (:class:`~.specialist_pool_service.CreateSpecialistPoolRequest`): + request (google.cloud.aiplatform_v1beta1.types.CreateSpecialistPoolRequest): The request object. Request message for ``SpecialistPoolService.CreateSpecialistPool``. - parent (:class:`str`): + parent (str): Required. The parent Project name for the new SpecialistPool. The form is ``projects/{project}/locations/{location}``. + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - specialist_pool (:class:`~.gca_specialist_pool.SpecialistPool`): + specialist_pool (google.cloud.aiplatform_v1beta1.types.SpecialistPool): Required. The SpecialistPool to create. + This corresponds to the ``specialist_pool`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -383,19 +397,17 @@ def create_specialist_pool( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.gca_specialist_pool.SpecialistPool`: - SpecialistPool represents customers' own workforce to - work on their data labeling jobs. It includes a group of - specialist managers who are responsible for managing the - labelers in this pool as well as customers' data - labeling jobs associated with this pool. Customers - create specialist pool as well as start data labeling - jobs on Cloud, managers and labelers work with the jobs - using CrowdCompute console. + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.SpecialistPool` SpecialistPool represents customers' own workforce to work on their data + labeling jobs. It includes a group of specialist + managers who are responsible for managing the + labelers in this pool as well as customers' data + labeling jobs associated with this pool. Customers + create specialist pool as well as start data labeling + jobs on Cloud, managers and labelers work with the + jobs using CrowdCompute console. """ # Create or coerce a protobuf request object. @@ -459,14 +471,15 @@ def get_specialist_pool( r"""Gets a SpecialistPool. Args: - request (:class:`~.specialist_pool_service.GetSpecialistPoolRequest`): + request (google.cloud.aiplatform_v1beta1.types.GetSpecialistPoolRequest): The request object. Request message for ``SpecialistPoolService.GetSpecialistPool``. - name (:class:`str`): + name (str): Required. The name of the SpecialistPool resource. The form is ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}``. + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -478,7 +491,7 @@ def get_specialist_pool( sent along with the request as metadata. Returns: - ~.specialist_pool.SpecialistPool: + google.cloud.aiplatform_v1beta1.types.SpecialistPool: SpecialistPool represents customers' own workforce to work on their data labeling jobs. It includes a group of @@ -543,13 +556,14 @@ def list_specialist_pools( r"""Lists SpecialistPools in a Location. Args: - request (:class:`~.specialist_pool_service.ListSpecialistPoolsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListSpecialistPoolsRequest): The request object. Request message for ``SpecialistPoolService.ListSpecialistPools``. - parent (:class:`str`): + parent (str): Required. The name of the SpecialistPool's parent resource. 
Format: ``projects/{project}/locations/{location}`` + This corresponds to the ``parent`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -561,7 +575,7 @@ def list_specialist_pools( sent along with the request as metadata. Returns: - ~.pagers.ListSpecialistPoolsPager: + google.cloud.aiplatform_v1beta1.services.specialist_pool_service.pagers.ListSpecialistPoolsPager: Response message for ``SpecialistPoolService.ListSpecialistPools``. @@ -627,13 +641,14 @@ def delete_specialist_pool( in the pool. Args: - request (:class:`~.specialist_pool_service.DeleteSpecialistPoolRequest`): + request (google.cloud.aiplatform_v1beta1.types.DeleteSpecialistPoolRequest): The request object. Request message for ``SpecialistPoolService.DeleteSpecialistPool``. - name (:class:`str`): + name (str): Required. The resource name of the SpecialistPool to delete. Format: ``projects/{project}/locations/{location}/specialistPools/{specialist_pool}`` + This corresponds to the ``name`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -645,24 +660,22 @@ def delete_specialist_pool( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. - The result type for the operation will be - :class:`~.empty.Empty`: A generic empty message that - you can re-use to avoid defining duplicated empty - messages in your APIs. A typical example is to use it as - the request or the response type of an API method. For - instance: + The result type for the operation will be :class:`google.protobuf.empty_pb2.Empty` A generic empty message that you can re-use to avoid defining duplicated + empty messages in your APIs. A typical example is to + use it as the request or the response type of an API + method. For instance: - :: + service Foo { + rpc Bar(google.protobuf.Empty) returns + (google.protobuf.Empty); - service Foo { - rpc Bar(google.protobuf.Empty) returns (google.protobuf.Empty); - } + } - The JSON representation for ``Empty`` is empty JSON - object ``{}``. + The JSON representation for Empty is empty JSON + object {}. """ # Create or coerce a protobuf request object. @@ -725,18 +738,20 @@ def update_specialist_pool( r"""Updates a SpecialistPool. Args: - request (:class:`~.specialist_pool_service.UpdateSpecialistPoolRequest`): + request (google.cloud.aiplatform_v1beta1.types.UpdateSpecialistPoolRequest): The request object. Request message for ``SpecialistPoolService.UpdateSpecialistPool``. - specialist_pool (:class:`~.gca_specialist_pool.SpecialistPool`): + specialist_pool (google.cloud.aiplatform_v1beta1.types.SpecialistPool): Required. The SpecialistPool which replaces the resource on the server. + This corresponds to the ``specialist_pool`` field on the ``request`` instance; if ``request`` is provided, this should not be set. - update_mask (:class:`~.field_mask.FieldMask`): + update_mask (google.protobuf.field_mask_pb2.FieldMask): Required. The update mask applies to the resource. + This corresponds to the ``update_mask`` field on the ``request`` instance; if ``request`` is provided, this should not be set. @@ -748,19 +763,17 @@ def update_specialist_pool( sent along with the request as metadata. Returns: - ~.ga_operation.Operation: + google.api_core.operation.Operation: An object representing a long-running operation. 
- The result type for the operation will be - :class:`~.gca_specialist_pool.SpecialistPool`: - SpecialistPool represents customers' own workforce to - work on their data labeling jobs. It includes a group of - specialist managers who are responsible for managing the - labelers in this pool as well as customers' data - labeling jobs associated with this pool. Customers - create specialist pool as well as start data labeling - jobs on Cloud, managers and labelers work with the jobs - using CrowdCompute console. + The result type for the operation will be :class:`google.cloud.aiplatform_v1beta1.types.SpecialistPool` SpecialistPool represents customers' own workforce to work on their data + labeling jobs. It includes a group of specialist + managers who are responsible for managing the + labelers in this pool as well as customers' data + labeling jobs associated with this pool. Customers + create specialist pool as well as start data labeling + jobs on Cloud, managers and labelers work with the + jobs using CrowdCompute console. """ # Create or coerce a protobuf request object. diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py index ff2d84ac74..61a5f5de57 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/pagers.py @@ -25,7 +25,7 @@ class ListSpecialistPoolsPager: """A pager for iterating through ``list_specialist_pools`` requests. This class thinly wraps an initial - :class:`~.specialist_pool_service.ListSpecialistPoolsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListSpecialistPoolsResponse` object, and provides an ``__iter__`` method to iterate through its ``specialist_pools`` field. @@ -34,7 +34,7 @@ class ListSpecialistPoolsPager: through the ``specialist_pools`` field on the corresponding responses. - All the usual :class:`~.specialist_pool_service.ListSpecialistPoolsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListSpecialistPoolsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -52,9 +52,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.specialist_pool_service.ListSpecialistPoolsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListSpecialistPoolsRequest): The initial request object. - response (:class:`~.specialist_pool_service.ListSpecialistPoolsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListSpecialistPoolsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. @@ -87,7 +87,7 @@ class ListSpecialistPoolsAsyncPager: """A pager for iterating through ``list_specialist_pools`` requests. This class thinly wraps an initial - :class:`~.specialist_pool_service.ListSpecialistPoolsResponse` object, and + :class:`google.cloud.aiplatform_v1beta1.types.ListSpecialistPoolsResponse` object, and provides an ``__aiter__`` method to iterate through its ``specialist_pools`` field. @@ -96,7 +96,7 @@ class ListSpecialistPoolsAsyncPager: through the ``specialist_pools`` field on the corresponding responses. 
- All the usual :class:`~.specialist_pool_service.ListSpecialistPoolsResponse` + All the usual :class:`google.cloud.aiplatform_v1beta1.types.ListSpecialistPoolsResponse` attributes are available on the pager. If multiple requests are made, only the most recent response is retained, and thus used for attribute lookup. """ @@ -116,9 +116,9 @@ def __init__( Args: method (Callable): The method that was originally called, and which instantiated this pager. - request (:class:`~.specialist_pool_service.ListSpecialistPoolsRequest`): + request (google.cloud.aiplatform_v1beta1.types.ListSpecialistPoolsRequest): The initial request object. - response (:class:`~.specialist_pool_service.ListSpecialistPoolsResponse`): + response (google.cloud.aiplatform_v1beta1.types.ListSpecialistPoolsResponse): The initial response object. metadata (Sequence[Tuple[str, str]]): Strings which should be sent along with the request as metadata. diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py index 711f7fd1cc..1bb2fbf22a 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/__init__.py @@ -30,7 +30,6 @@ _transport_registry["grpc"] = SpecialistPoolServiceGrpcTransport _transport_registry["grpc_asyncio"] = SpecialistPoolServiceGrpcAsyncIOTransport - __all__ = ( "SpecialistPoolServiceTransport", "SpecialistPoolServiceGrpcTransport", diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py index a39c2f1f71..f1af058030 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/base.py @@ -72,10 +72,10 @@ def __init__( scope (Optional[Sequence[str]]): A list of scopes. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. """ # Save the hostname. Default to port 443 (HTTPS) if none is specified. 
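The ``from_service_account_info`` classmethod added to the client (and aliased on the async client) accepts an already-parsed key dict, complementing ``from_service_account_file``. A short sketch combining it with the pager behavior documented above; the key-file path and parent resource name are placeholders:

import json

from google.cloud.aiplatform_v1beta1.services.specialist_pool_service import (
    SpecialistPoolServiceClient,
)

# Placeholder key file; from_service_account_info takes the parsed dict,
# which is useful when the key material is not stored on disk.
with open("service-account.json") as fh:
    info = json.load(fh)

client = SpecialistPoolServiceClient.from_service_account_info(info)

# The pager transparently issues follow-up ListSpecialistPools requests
# while iterating the specialist_pools field of each response.
for pool in client.list_specialist_pools(
    parent="projects/my-project/locations/us-central1"
):
    print(pool.name)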
diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py index 2fc7e0881b..61c82508b9 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc.py @@ -65,6 +65,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -95,6 +96,10 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -111,6 +116,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. @@ -120,11 +130,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -154,6 +159,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -164,17 +173,28 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._stubs = {} # type: Dict[str, Callable] + self._operations_client = None # Run the base constructor. super().__init__( @@ -198,7 +218,7 @@ def create_channel( ) -> grpc.Channel: """Create and return a gRPC channel object. Args: - address (Optionsl[str]): The host for the channel to use. + address (Optional[str]): The host for the channel to use. credentials (Optional[~.Credentials]): The authorization credentials to attach to requests. 
These credentials identify this application to the service. If @@ -233,7 +253,8 @@ def create_channel( @property def grpc_channel(self) -> grpc.Channel: - """Return the channel designed to connect to this service.""" + """Return the channel designed to connect to this service. + """ return self._grpc_channel @property @@ -244,13 +265,11 @@ def operations_client(self) -> operations_v1.OperationsClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsClient( - self.grpc_channel - ) + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient(self.grpc_channel) # Return the client from cache. - return self.__dict__["operations_client"] + return self._operations_client @property def create_specialist_pool( diff --git a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py index a6d3b045e6..a71d380b5b 100644 --- a/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/specialist_pool_service/transports/grpc_asyncio.py @@ -109,6 +109,7 @@ def __init__( api_mtls_endpoint: str = None, client_cert_source: Callable[[], Tuple[bytes, bytes]] = None, ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id=None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -140,12 +141,16 @@ def __init__( ``api_mtls_endpoint`` is None. ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure mutual TLS channel. It is + ignored if ``channel`` or ``ssl_channel_credentials`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. - client_info (google.api_core.gapic_v1.client_info.ClientInfo): - The client info used to send a user-agent string along with - API requests. If ``None``, then default info will be used. - Generally, you only need to set this if you're developing + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing your own client library. Raises: @@ -156,6 +161,11 @@ def __init__( """ self._ssl_channel_credentials = ssl_channel_credentials + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + if channel: # Sanity check: Ensure that channel and credentials are not both # provided. 
@@ -165,11 +175,6 @@ def __init__( self._grpc_channel = channel self._ssl_channel_credentials = None elif api_mtls_endpoint: - warnings.warn( - "api_mtls_endpoint and client_cert_source are deprecated", - DeprecationWarning, - ) - host = ( api_mtls_endpoint if ":" in api_mtls_endpoint @@ -199,6 +204,10 @@ def __init__( ssl_credentials=ssl_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) self._ssl_channel_credentials = ssl_credentials else: @@ -209,14 +218,24 @@ def __init__( scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id ) + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + # create a new channel. The provided one is ignored. self._grpc_channel = type(self).create_channel( host, credentials=credentials, credentials_file=credentials_file, - ssl_credentials=ssl_channel_credentials, + ssl_credentials=self._ssl_channel_credentials, scopes=scopes or self.AUTH_SCOPES, quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) # Run the base constructor. @@ -230,6 +249,7 @@ def __init__( ) self._stubs = {} + self._operations_client = None @property def grpc_channel(self) -> aio.Channel: @@ -249,13 +269,13 @@ def operations_client(self) -> operations_v1.OperationsAsyncClient: client. """ # Sanity check: Only create a new client if we do not already have one. - if "operations_client" not in self.__dict__: - self.__dict__["operations_client"] = operations_v1.OperationsAsyncClient( + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( self.grpc_channel ) # Return the client from cache. 
- return self.__dict__["operations_client"] + return self._operations_client @property def create_specialist_pool( diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index c668a7be98..ca848c7c54 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -19,6 +19,7 @@ from .annotation import Annotation from .annotation_spec import AnnotationSpec from .completion_stats import CompletionStats +from .encryption_spec import EncryptionSpec from .explanation_metadata import ExplanationMetadata from .explanation import ( Explanation, @@ -31,6 +32,8 @@ XraiAttribution, SmoothGradConfig, FeatureNoiseSigma, + ExplanationSpecOverride, + ExplanationMetadataOverride, ) from .io import ( GcsSource, @@ -217,12 +220,13 @@ UpdateSpecialistPoolOperationMetadata, ) - __all__ = ( + "AcceleratorType", "UserActionReference", "Annotation", "AnnotationSpec", "CompletionStats", + "EncryptionSpec", "ExplanationMetadata", "Explanation", "ModelExplanation", @@ -234,11 +238,14 @@ "XraiAttribution", "SmoothGradConfig", "FeatureNoiseSigma", + "ExplanationSpecOverride", + "ExplanationMetadataOverride", "GcsSource", "GcsDestination", "BigQuerySource", "BigQueryDestination", "ContainerRegistryDestination", + "JobState", "MachineSpec", "DedicatedResources", "AutomaticResources", @@ -270,6 +277,7 @@ "PredictSchemata", "ModelContainerSpec", "Port", + "PipelineState", "TrainingPipeline", "InputDataConfig", "FractionSplit", diff --git a/google/cloud/aiplatform_v1beta1/types/annotation.py b/google/cloud/aiplatform_v1beta1/types/annotation.py index 7734fcc512..74bb5eac98 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation.py @@ -46,22 +46,22 @@ class Annotation(proto.Message): note that the chosen schema must be consistent with the parent Dataset's ``metadata``. - payload (~.struct.Value): + payload (google.protobuf.struct_pb2.Value): Required. The schema of the payload can be found in ``payload_schema``. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this Annotation was created. - update_time (~.timestamp.Timestamp): + update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this Annotation was last updated. etag (str): Optional. Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update happens. - annotation_source (~.user_action_reference.UserActionReference): + annotation_source (google.cloud.aiplatform_v1beta1.types.UserActionReference): Output only. The source of the Annotation. - labels (Sequence[~.annotation.Annotation.LabelsEntry]): + labels (Sequence[google.cloud.aiplatform_v1beta1.types.Annotation.LabelsEntry]): Optional. The labels with user-defined metadata to organize your Annotations. diff --git a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py index a5a4b3d489..9d35539a5b 100644 --- a/google/cloud/aiplatform_v1beta1/types/annotation_spec.py +++ b/google/cloud/aiplatform_v1beta1/types/annotation_spec.py @@ -39,10 +39,10 @@ class AnnotationSpec(proto.Message): AnnotationSpec. The name can be up to 128 characters long and can consist of any UTF-8 characters. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this AnnotationSpec was created.
- update_time (~.timestamp.Timestamp): + update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when AnnotationSpec was last updated. etag (str): diff --git a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py index 625bf83155..a51de2f9a2 100644 --- a/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py +++ b/google/cloud/aiplatform_v1beta1/types/batch_prediction_job.py @@ -21,6 +21,7 @@ from google.cloud.aiplatform_v1beta1.types import ( completion_stats as gca_completion_stats, ) +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_state @@ -60,20 +61,20 @@ class BatchPredictionJob(proto.Message): same ancestor Location. Starting this job has no impact on any existing deployments of the Model and their resources. - input_config (~.batch_prediction_job.BatchPredictionJob.InputConfig): + input_config (google.cloud.aiplatform_v1beta1.types.BatchPredictionJob.InputConfig): Required. Input configuration of the instances on which predictions are performed. The schema of any single instance may be specified via the [Model's][google.cloud.aiplatform.v1beta1.BatchPredictionJob.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``instance_schema_uri``. - model_parameters (~.struct.Value): + model_parameters (google.protobuf.struct_pb2.Value): The parameters that govern the predictions. The schema of the parameters may be specified via the [Model's][google.cloud.aiplatform.v1beta1.BatchPredictionJob.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``parameters_schema_uri``. - output_config (~.batch_prediction_job.BatchPredictionJob.OutputConfig): + output_config (google.cloud.aiplatform_v1beta1.types.BatchPredictionJob.OutputConfig): Required. The Configuration specifying where output predictions should be written. The schema of any single prediction may be specified as a concatenation of @@ -82,73 +83,73 @@ class BatchPredictionJob(proto.Message): ``instance_schema_uri`` and ``prediction_schema_uri``. - dedicated_resources (~.machine_resources.BatchDedicatedResources): + dedicated_resources (google.cloud.aiplatform_v1beta1.types.BatchDedicatedResources): The config of resources used by the Model during the batch prediction. If the Model ``supports`` DEDICATED_RESOURCES this config may be provided (and the job will use these resources), if the Model doesn't support AUTOMATIC_RESOURCES, this config must be provided. - manual_batch_tuning_parameters (~.gca_manual_batch_tuning_parameters.ManualBatchTuningParameters): + manual_batch_tuning_parameters (google.cloud.aiplatform_v1beta1.types.ManualBatchTuningParameters): Immutable. Parameters configuring the batch behavior. Currently only applicable when ``dedicated_resources`` are used (in other cases AI Platform does the tuning itself). generate_explanation (bool): - Generate explanation along with the batch prediction - results. + Generate explanation with the batch prediction results. 
- When it's true, the batch prediction output will change - based on the [output - format][BatchPredictionJob.output_config.predictions_format]: + When set to ``true``, the batch prediction output changes + based on the ``predictions_format`` field of the + ``BatchPredictionJob.output_config`` + object: - - ``bigquery``: output will include a column named + - ``bigquery``: output includes a column named ``explanation``. The value is a struct that conforms to the ``Explanation`` object. - - ``jsonl``: The JSON objects on each line will include an + - ``jsonl``: The JSON objects on each line include an additional entry keyed ``explanation``. The value of the entry is a JSON object that conforms to the ``Explanation`` object. - ``csv``: Generating explanations for CSV format is not supported. - explanation_spec (~.explanation.ExplanationSpec): + + If this field is set to true, the + ``Model.explanation_spec`` + must be populated. + explanation_spec (google.cloud.aiplatform_v1beta1.types.ExplanationSpec): Explanation configuration for this BatchPredictionJob. Can - only be specified if + be specified only if ``generate_explanation`` - is set to ``true``. It's invalid to specified it with - generate_explanation set to false or unset. + is set to ``true``. This value overrides the value of ``Model.explanation_spec``. All fields of ``explanation_spec`` - are optional in the request. If a field of + are optional in the request. If a field of the ``explanation_spec`` - is not populated, the value of the same field of - ``Model.explanation_spec`` - is inherited. The corresponding + object is not populated, the corresponding field of the ``Model.explanation_spec`` - must be populated, otherwise explanation for this Model is - not allowed. - output_info (~.batch_prediction_job.BatchPredictionJob.OutputInfo): + object is inherited. + output_info (google.cloud.aiplatform_v1beta1.types.BatchPredictionJob.OutputInfo): Output only. Information further describing the output of this job. - state (~.job_state.JobState): + state (google.cloud.aiplatform_v1beta1.types.JobState): Output only. The detailed state of the job. - error (~.status.Status): + error (google.rpc.status_pb2.Status): Output only. Only populated when the job's state is JOB_STATE_FAILED or JOB_STATE_CANCELLED. - partial_failures (Sequence[~.status.Status]): + partial_failures (Sequence[google.rpc.status_pb2.Status]): Output only. Partial failures encountered. For example, single files that can't be read. This field never exceeds 20 entries. Status details fields contain standard GCP error details. - resources_consumed (~.machine_resources.ResourcesConsumed): + resources_consumed (google.cloud.aiplatform_v1beta1.types.ResourcesConsumed): Output only. Information about resources that had been consumed by this job. Provided in real time at best effort basis, as well as a final @@ -156,23 +157,23 @@ class BatchPredictionJob(proto.Message): Note: This field currently may be not populated for batch predictions that use AutoML Models. - completion_stats (~.gca_completion_stats.CompletionStats): + completion_stats (google.cloud.aiplatform_v1beta1.types.CompletionStats): Output only. Statistics on completed and failed prediction instances. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the BatchPredictionJob was created. - start_time (~.timestamp.Timestamp): + start_time (google.protobuf.timestamp_pb2.Timestamp): Output only. 
Time when the BatchPredictionJob for the first time entered the ``JOB_STATE_RUNNING`` state. - end_time (~.timestamp.Timestamp): + end_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the BatchPredictionJob entered any of the following states: ``JOB_STATE_SUCCEEDED``, ``JOB_STATE_FAILED``, ``JOB_STATE_CANCELLED``. - update_time (~.timestamp.Timestamp): + update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the BatchPredictionJob was most recently updated. - labels (Sequence[~.batch_prediction_job.BatchPredictionJob.LabelsEntry]): + labels (Sequence[google.cloud.aiplatform_v1beta1.types.BatchPredictionJob.LabelsEntry]): The labels with user-defined metadata to organize BatchPredictionJobs. Label keys and values can be no longer than 64 @@ -182,6 +183,11 @@ class BatchPredictionJob(proto.Message): are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. + encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): + Customer-managed encryption key options for a + BatchPredictionJob. If this is set, then all + resources created by the BatchPredictionJob will + be encrypted with the provided encryption key. """ class InputConfig(proto.Message): @@ -193,10 +199,10 @@ class InputConfig(proto.Message): expressed via any of them. Attributes: - gcs_source (~.io.GcsSource): - The Google Cloud Storage location for the - input instances. - bigquery_source (~.io.BigQuerySource): + gcs_source (google.cloud.aiplatform_v1beta1.types.GcsSource): + The Cloud Storage location for the input + instances. + bigquery_source (google.cloud.aiplatform_v1beta1.types.BigQuerySource): The BigQuery location of the input table. The schema of the table should be in the format described by the given context OpenAPI Schema, @@ -229,10 +235,10 @@ class OutputConfig(proto.Message): any of them. Attributes: - gcs_destination (~.io.GcsDestination): - The Google Cloud Storage location of the directory where the - output is to be written to. In the given directory a new - directory is created. Its name is + gcs_destination (google.cloud.aiplatform_v1beta1.types.GcsDestination): + The Cloud Storage location of the directory where the output + is to be written to. In the given directory a new directory + is created. Its name is ``prediction--``, where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format. Inside of it files ``predictions_0001.``, @@ -256,7 +262,7 @@ class OutputConfig(proto.Message): per their schema, followed by an additional ``error`` field which as value has ```google.rpc.Status`` `__ containing only ``code`` and ``message`` fields. - bigquery_destination (~.io.BigQueryDestination): + bigquery_destination (google.cloud.aiplatform_v1beta1.types.BigQueryDestination): The BigQuery project location where the output is to be written to. In the given project a new dataset is created with name @@ -305,8 +311,8 @@ class OutputInfo(proto.Message): Attributes: gcs_output_directory (str): - Output only. The full path of the Google - Cloud Storage directory created, into which the + Output only. The full path of the Cloud + Storage directory created, into which the prediction output is written. bigquery_output_dataset (str): Output only. 
The path of the BigQuery dataset created, in @@ -378,5 +384,9 @@ class OutputInfo(proto.Message): labels = proto.MapField(proto.STRING, proto.STRING, number=19) + encryption_spec = proto.Field( + proto.MESSAGE, number=24, message=gca_encryption_spec.EncryptionSpec, + ) + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/custom_job.py b/google/cloud/aiplatform_v1beta1/types/custom_job.py index 2d8745538c..a1674327ef 100644 --- a/google/cloud/aiplatform_v1beta1/types/custom_job.py +++ b/google/cloud/aiplatform_v1beta1/types/custom_job.py @@ -18,6 +18,7 @@ import proto # type: ignore +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import machine_resources @@ -53,27 +54,27 @@ class CustomJob(proto.Message): Required. The display name of the CustomJob. The name can be up to 128 characters long and can consist of any UTF-8 characters. - job_spec (~.custom_job.CustomJobSpec): + job_spec (google.cloud.aiplatform_v1beta1.types.CustomJobSpec): Required. Job spec. - state (~.job_state.JobState): + state (google.cloud.aiplatform_v1beta1.types.JobState): Output only. The detailed state of the job. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the CustomJob was created. - start_time (~.timestamp.Timestamp): + start_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the CustomJob for the first time entered the ``JOB_STATE_RUNNING`` state. - end_time (~.timestamp.Timestamp): + end_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the CustomJob entered any of the following states: ``JOB_STATE_SUCCEEDED``, ``JOB_STATE_FAILED``, ``JOB_STATE_CANCELLED``. - update_time (~.timestamp.Timestamp): + update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the CustomJob was most recently updated. - error (~.status.Status): + error (google.rpc.status_pb2.Status): Output only. Only populated when job's state is ``JOB_STATE_FAILED`` or ``JOB_STATE_CANCELLED``. - labels (Sequence[~.custom_job.CustomJob.LabelsEntry]): + labels (Sequence[google.cloud.aiplatform_v1beta1.types.CustomJob.LabelsEntry]): The labels with user-defined metadata to organize CustomJobs. Label keys and values can be no longer than 64 @@ -83,6 +84,11 @@ class CustomJob(proto.Message): are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. + encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): + Customer-managed encryption key options for a + CustomJob. If this is set, then all resources + created by the CustomJob will be encrypted with + the provided encryption key. """ name = proto.Field(proto.STRING, number=1) @@ -105,48 +111,51 @@ class CustomJob(proto.Message): labels = proto.MapField(proto.STRING, proto.STRING, number=11) + encryption_spec = proto.Field( + proto.MESSAGE, number=12, message=gca_encryption_spec.EncryptionSpec, + ) + class CustomJobSpec(proto.Message): r"""Represents the spec of a CustomJob. Attributes: - worker_pool_specs (Sequence[~.custom_job.WorkerPoolSpec]): + worker_pool_specs (Sequence[google.cloud.aiplatform_v1beta1.types.WorkerPoolSpec]): Required. The spec of the worker pools including machine type and Docker image.
- scheduling (~.custom_job.Scheduling): + scheduling (google.cloud.aiplatform_v1beta1.types.Scheduling): Scheduling options for a CustomJob. service_account (str): Specifies the service account for workload run-as account. Users submitting jobs must have - act-as permission on this run-as account. + act-as permission on this run-as account. If + unspecified, the AI Platform Custom Code Service + Agent for the CustomJob's project is used. network (str): The full name of the Compute Engine `network `__ to which the Job should be peered. For example, - projects/12345/global/networks/myVPC. - - [Format](https://cloud.google.com/compute/docs/reference/rest/v1/networks/insert) - is of the form projects/{project}/global/networks/{network}. - Where {project} is a project number, as in '12345', and - {network} is network name. + ``projects/12345/global/networks/myVPC``. + `Format `__ + is of the form + ``projects/{project}/global/networks/{network}``. Where + {project} is a project number, as in ``12345``, and + {network} is a network name. Private services access must already be configured for the network. If left unspecified, the job is not peered with any network. - base_output_directory (~.io.GcsDestination): - The Google Cloud Storage location to store the output of - this CustomJob or HyperparameterTuningJob. For - HyperparameterTuningJob, - ``base_output_directory`` - of each child CustomJob backing a Trial is set to a - subdirectory of name - ``id`` under parent - HyperparameterTuningJob's - - ``base_output_directory``. - - Following AI Platform environment variables will be passed - to containers or python modules when this field is set: + base_output_directory (google.cloud.aiplatform_v1beta1.types.GcsDestination): + The Cloud Storage location to store the output of this + CustomJob or HyperparameterTuningJob. For + HyperparameterTuningJob, the baseOutputDirectory of each + child CustomJob backing a Trial is set to a subdirectory of + name ``id`` under + its parent HyperparameterTuningJob's baseOutputDirectory. + + The following AI Platform environment variables will be + passed to containers or python modules when this field is + set: For CustomJob: @@ -185,17 +194,17 @@ class WorkerPoolSpec(proto.Message): r"""Represents the spec of a worker pool in a job. Attributes: - container_spec (~.custom_job.ContainerSpec): + container_spec (google.cloud.aiplatform_v1beta1.types.ContainerSpec): The custom container task. - python_package_spec (~.custom_job.PythonPackageSpec): + python_package_spec (google.cloud.aiplatform_v1beta1.types.PythonPackageSpec): The Python packaged task. - machine_spec (~.machine_resources.MachineSpec): - Required. Immutable. The specification of a + machine_spec (google.cloud.aiplatform_v1beta1.types.MachineSpec): + Optional. Immutable. The specification of a single machine. replica_count (int): - Required. The number of worker replicas to + Optional. The number of worker replicas to use for this worker pool. - disk_spec (~.machine_resources.DiskSpec): + disk_spec (google.cloud.aiplatform_v1beta1.types.DiskSpec): Disk spec. """ @@ -280,7 +289,7 @@ class Scheduling(proto.Message): jobs. Attributes: - timeout (~.duration.Duration): + timeout (google.protobuf.duration_pb2.Duration): The maximum job running time. The default is 7 days. 
restart_job_on_worker_restart (bool): diff --git a/google/cloud/aiplatform_v1beta1/types/data_item.py b/google/cloud/aiplatform_v1beta1/types/data_item.py index e43a944d94..eff2516bda 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_item.py +++ b/google/cloud/aiplatform_v1beta1/types/data_item.py @@ -35,13 +35,13 @@ class DataItem(proto.Message): name (str): Output only. The resource name of the DataItem. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this DataItem was created. - update_time (~.timestamp.Timestamp): + update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this DataItem was last updated. - labels (Sequence[~.data_item.DataItem.LabelsEntry]): + labels (Sequence[google.cloud.aiplatform_v1beta1.types.DataItem.LabelsEntry]): Optional. The labels with user-defined metadata to organize your DataItems. Label keys and values can be no longer than 64 @@ -56,7 +56,7 @@ class DataItem(proto.Message): and examples of labels. System reserved label keys are prefixed with "aiplatform.googleapis.com/" and are immutable. - payload (~.struct.Value): + payload (google.protobuf.struct_pb2.Value): Required. The data that the DataItem represents (for example, an image or a text snippet). The schema of the payload is stored in the parent Dataset's [metadata diff --git a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py index af1bcdd871..dc06549ac4 100644 --- a/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py +++ b/google/cloud/aiplatform_v1beta1/types/data_labeling_job.py @@ -18,6 +18,7 @@ import proto # type: ignore +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform_v1beta1.types import job_state from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore @@ -54,7 +55,7 @@ class DataLabelingJob(proto.Message): Required. Dataset resource names. Right now we only support labeling from a single Dataset. Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` - annotation_labels (Sequence[~.data_labeling_job.DataLabelingJob.AnnotationLabelsEntry]): + annotation_labels (Sequence[google.cloud.aiplatform_v1beta1.types.DataLabelingJob.AnnotationLabelsEntry]): Labels to assign to annotations generated by this DataLabelingJob. Label keys and values can be no longer than 64 @@ -81,29 +82,29 @@ class DataLabelingJob(proto.Message): https://storage.googleapis.com/google-cloud-aiplatform bucket in the /schema/datalabelingjob/inputs/ folder. - inputs (~.struct.Value): + inputs (google.protobuf.struct_pb2.Value): Required. Input config parameters for the DataLabelingJob. - state (~.job_state.JobState): + state (google.cloud.aiplatform_v1beta1.types.JobState): Output only. The detailed state of the job. labeling_progress (int): Output only. Current labeling job progress percentage scaled in interval [0, 100], indicating the percentage of DataItems that have been finished. - current_spend (~.money.Money): + current_spend (google.type.money_pb2.Money): Output only. Estimated cost (in US dollars) that the DataLabelingJob has incurred to date. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this DataLabelingJob was created.
- update_time (~.timestamp.Timestamp): + update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this DataLabelingJob was updated most recently. - error (~.status.Status): + error (google.rpc.status_pb2.Status): Output only. DataLabelingJob errors. It is only populated when job's state is ``JOB_STATE_FAILED`` or ``JOB_STATE_CANCELLED``. - labels (Sequence[~.data_labeling_job.DataLabelingJob.LabelsEntry]): + labels (Sequence[google.cloud.aiplatform_v1beta1.types.DataLabelingJob.LabelsEntry]): The labels with user-defined metadata to organize your DataLabelingJobs. @@ -124,8 +125,15 @@ class DataLabelingJob(proto.Message): specialist_pools (Sequence[str]): The SpecialistPools' resource names associated with this job. - active_learning_config (~.data_labeling_job.ActiveLearningConfig): - Paramaters that configure active learning + encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): + Customer-managed encryption key spec for a + DataLabelingJob. If set, this DataLabelingJob + will be secured by this key. + Note: Annotations created in the DataLabelingJob + are associated with the EncryptionSpec of the + Dataset they are exported to. + active_learning_config (google.cloud.aiplatform_v1beta1.types.ActiveLearningConfig): + Parameters that configure active learning pipeline. Active learning will label the data incrementally via several iterations. For every iteration, it will select a batch of data based @@ -164,13 +172,17 @@ class DataLabelingJob(proto.Message): specialist_pools = proto.RepeatedField(proto.STRING, number=16) + encryption_spec = proto.Field( + proto.MESSAGE, number=20, message=gca_encryption_spec.EncryptionSpec, + ) + active_learning_config = proto.Field( proto.MESSAGE, number=21, message="ActiveLearningConfig", ) class ActiveLearningConfig(proto.Message): - r"""Paramaters that configure active learning pipeline. Active + r"""Parameters that configure active learning pipeline. Active learning will label the data incrementally by several iterations. For every iteration, it will select a batch of data based on the sampling strategy. @@ -181,12 +193,12 @@ class ActiveLearningConfig(proto.Message): max_data_item_percentage (int): Max percent of total DataItems for human labeling. - sample_config (~.data_labeling_job.SampleConfig): + sample_config (google.cloud.aiplatform_v1beta1.types.SampleConfig): Active learning data sampling config. For every active learning labeling iteration, it will select a batch of data based on the sampling strategy. - training_config (~.data_labeling_job.TrainingConfig): + training_config (google.cloud.aiplatform_v1beta1.types.TrainingConfig): CMLE training config. For every active learning labeling iteration, system will train a machine learning model on CMLE. The trained @@ -220,7 +232,7 @@ class SampleConfig(proto.Message): The percentage of data needed to be labeled in each following batch (except the first batch). - sample_strategy (~.data_labeling_job.SampleConfig.SampleStrategy): + sample_strategy (google.cloud.aiplatform_v1beta1.types.SampleConfig.SampleStrategy): Field to chose sampling strategy. Sampling strategy will decide which data should be selected for human labeling in every batch. 
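Putting the new encryption_spec field (number 20) and the corrected active-learning docstrings together, here is a hedged sketch of a CMEK-protected DataLabelingJob; every resource name, the KMS key, and the sampling numbers are placeholders:

    from google.cloud.aiplatform_v1beta1 import types

    labeling_job = types.DataLabelingJob(
        display_name="my-labeling-job",
        datasets=["projects/12345/locations/us-central1/datasets/678"],
        encryption_spec=types.EncryptionSpec(
            kms_key_name=(
                "projects/my-project/locations/us-central1/"
                "keyRings/my-kr/cryptoKeys/my-key"
            ),
        ),
        active_learning_config=types.ActiveLearningConfig(
            max_data_item_count=1000,  # alternatively max_data_item_percentage
            sample_config=types.SampleConfig(
                initial_batch_sample_percentage=10,
                following_batch_sample_percentage=5,
                sample_strategy=types.SampleConfig.SampleStrategy.UNCERTAINTY,
            ),
        ),
    )

Note that, per the docstring above, annotations created in the job follow the EncryptionSpec of the Dataset they are exported to, not this key.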
diff --git a/google/cloud/aiplatform_v1beta1/types/dataset.py b/google/cloud/aiplatform_v1beta1/types/dataset.py index 76f6462f40..2ca1244527 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset.py @@ -18,6 +18,7 @@ import proto # type: ignore +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform_v1beta1.types import io from google.protobuf import struct_pb2 as struct # type: ignore from google.protobuf import timestamp_pb2 as timestamp # type: ignore @@ -48,20 +49,20 @@ class Dataset(proto.Message): schema files that can be used here are found in gs://google-cloud- aiplatform/schema/dataset/metadata/. - metadata (~.struct.Value): + metadata (google.protobuf.struct_pb2.Value): Required. Additional information about the Dataset. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this Dataset was created. - update_time (~.timestamp.Timestamp): + update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this Dataset was last updated. etag (str): Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update happens. - labels (Sequence[~.dataset.Dataset.LabelsEntry]): + labels (Sequence[google.cloud.aiplatform_v1beta1.types.Dataset.LabelsEntry]): The labels with user-defined metadata to organize your Datasets. @@ -80,6 +81,11 @@ class Dataset(proto.Message): output only, its value is the [metadata_schema's][google.cloud.aiplatform.v1beta1.Dataset.metadata_schema_uri] title. + encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): + Customer-managed encryption key spec for a + Dataset. If set, this Dataset and all sub- + resources of this Dataset will be secured by + this key. """ name = proto.Field(proto.STRING, number=1) @@ -98,6 +104,10 @@ class Dataset(proto.Message): labels = proto.MapField(proto.STRING, proto.STRING, number=7) + encryption_spec = proto.Field( + proto.MESSAGE, number=11, message=gca_encryption_spec.EncryptionSpec, + ) + class ImportDataConfig(proto.Message): r"""Describes the location from where we import data into a @@ -105,10 +115,10 @@ class ImportDataConfig(proto.Message): DataItems and the Annotations. Attributes: - gcs_source (~.io.GcsSource): + gcs_source (google.cloud.aiplatform_v1beta1.types.GcsSource): The Google Cloud Storage location for the input content. - data_item_labels (Sequence[~.dataset.ImportDataConfig.DataItemLabelsEntry]): + data_item_labels (Sequence[google.cloud.aiplatform_v1beta1.types.ImportDataConfig.DataItemLabelsEntry]): Labels that will be applied to newly imported DataItems. If an identical DataItem as one being imported already exists in the Dataset, then these labels will be appended to these @@ -145,7 +155,7 @@ class ExportDataConfig(proto.Message): destination of the export and how to export. Attributes: - gcs_destination (~.io.GcsDestination): + gcs_destination (google.cloud.aiplatform_v1beta1.types.GcsDestination): The Google Cloud Storage location where the output is to be written to. 
In the given directory a new directory will be created with name: diff --git a/google/cloud/aiplatform_v1beta1/types/dataset_service.py b/google/cloud/aiplatform_v1beta1/types/dataset_service.py index 7160b7b52f..91c64f4b5d 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset_service.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset_service.py @@ -59,7 +59,7 @@ class CreateDatasetRequest(proto.Message): Required. The resource name of the Location to create the Dataset in. Format: ``projects/{project}/locations/{location}`` - dataset (~.gca_dataset.Dataset): + dataset (google.cloud.aiplatform_v1beta1.types.Dataset): Required. The Dataset to create. """ @@ -73,7 +73,7 @@ class CreateDatasetOperationMetadata(proto.Message): ``DatasetService.CreateDataset``. Attributes: - generic_metadata (~.operation.GenericOperationMetadata): + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The operation generic information. """ @@ -89,7 +89,7 @@ class GetDatasetRequest(proto.Message): Attributes: name (str): Required. The name of the Dataset resource. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. """ @@ -103,14 +103,13 @@ class UpdateDatasetRequest(proto.Message): ``DatasetService.UpdateDataset``. Attributes: - dataset (~.gca_dataset.Dataset): + dataset (google.cloud.aiplatform_v1beta1.types.Dataset): Required. The Dataset which replaces the resource on the server. - update_mask (~.field_mask.FieldMask): + update_mask (google.protobuf.field_mask_pb2.FieldMask): Required. The update mask applies to the resource. For the ``FieldMask`` definition, see - - [FieldMask](https://tinyurl.com/dev-google-protobuf#google.protobuf.FieldMask). + `FieldMask <https://tinyurl.com/dev-google-protobuf#google.protobuf.FieldMask>`__. Updatable fields: - ``display_name`` @@ -132,12 +131,27 @@ class ListDatasetsRequest(proto.Message): Required. The name of the Dataset's parent resource. Format: ``projects/{project}/locations/{location}`` filter (str): - The standard list filter. + An expression for filtering the results of the request. For + field names both snake_case and camelCase are supported. + + - ``display_name``: supports = and != + - ``metadata_schema_uri``: supports = and != + - ``labels`` supports general map functions that is: + + - ``labels.key=value`` - key:value equality + - \`labels.key:\* or labels:key - key existence + - A key including a space must be quoted. + ``labels."a key"``. + + Some examples: + + - ``displayName="myDisplayName"`` + - ``labels.myKey="myValue"`` page_size (int): The standard list page size. page_token (str): The standard list page token. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. order_by (str): A comma-separated list of fields to order by, sorted in @@ -145,7 +159,7 @@ descending. Supported fields: - ``display_name`` - - ``data_item_count`` \* ``create_time`` + - ``create_time`` - ``update_time`` """ @@ -167,7 +181,7 @@ class ListDatasetsResponse(proto.Message): ``DatasetService.ListDatasets``. Attributes: - datasets (Sequence[~.gca_dataset.Dataset]): + datasets (Sequence[google.cloud.aiplatform_v1beta1.types.Dataset]): A list of Datasets that matches the specified filter in the request. next_page_token (str): @@ -207,7 +221,7 @@ class ImportDataRequest(proto.Message): name (str): Required. The name of the Dataset resource.
Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` - import_configs (Sequence[~.gca_dataset.ImportDataConfig]): + import_configs (Sequence[google.cloud.aiplatform_v1beta1.types.ImportDataConfig]): Required. The desired input locations. The contents of all input locations will be imported in one batch. @@ -231,7 +245,7 @@ class ImportDataOperationMetadata(proto.Message): ``DatasetService.ImportData``. Attributes: - generic_metadata (~.operation.GenericOperationMetadata): + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The common part of the operation metadata. """ @@ -248,7 +262,7 @@ class ExportDataRequest(proto.Message): name (str): Required. The name of the Dataset resource. Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` - export_config (~.gca_dataset.ExportDataConfig): + export_config (google.cloud.aiplatform_v1beta1.types.ExportDataConfig): Required. The desired output location. """ @@ -277,7 +291,7 @@ class ExportDataOperationMetadata(proto.Message): ``DatasetService.ExportData``. Attributes: - generic_metadata (~.operation.GenericOperationMetadata): + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The common part of the operation metadata. gcs_output_directory (str): A Google Cloud Storage directory which path @@ -307,7 +321,7 @@ class ListDataItemsRequest(proto.Message): The standard list page size. page_token (str): The standard list page token. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. order_by (str): A comma-separated list of fields to order by, @@ -333,7 +347,7 @@ class ListDataItemsResponse(proto.Message): ``DatasetService.ListDataItems``. Attributes: - data_items (Sequence[~.data_item.DataItem]): + data_items (Sequence[google.cloud.aiplatform_v1beta1.types.DataItem]): A list of DataItems that matches the specified filter in the request. next_page_token (str): @@ -360,7 +374,7 @@ class GetAnnotationSpecRequest(proto.Message): Required. The name of the AnnotationSpec resource. Format: ``projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}`` - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. """ @@ -385,7 +399,7 @@ class ListAnnotationsRequest(proto.Message): The standard list page size. page_token (str): The standard list page token. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. order_by (str): A comma-separated list of fields to order by, @@ -411,7 +425,7 @@ class ListAnnotationsResponse(proto.Message): ``DatasetService.ListAnnotations``. Attributes: - annotations (Sequence[~.annotation.Annotation]): + annotations (Sequence[google.cloud.aiplatform_v1beta1.types.Annotation]): A list of Annotations that matches the specified filter in the request. next_page_token (str): diff --git a/google/cloud/aiplatform_v1beta1/types/encryption_spec.py b/google/cloud/aiplatform_v1beta1/types/encryption_spec.py new file mode 100644 index 0000000000..0d41d39a0b --- /dev/null +++ b/google/cloud/aiplatform_v1beta1/types/encryption_spec.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import proto # type: ignore + + +__protobuf__ = proto.module( + package="google.cloud.aiplatform.v1beta1", manifest={"EncryptionSpec",}, +) + + +class EncryptionSpec(proto.Message): + r"""Represents a customer-managed encryption key spec that can be + applied to a top-level resource. + + Attributes: + kms_key_name (str): + Required. The Cloud KMS resource identifier of the customer + managed encryption key used to protect a resource. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + """ + + kms_key_name = proto.Field(proto.STRING, number=1) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint.py b/google/cloud/aiplatform_v1beta1/types/endpoint.py index f1ba6ed85d..bdbcb6ff21 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint.py @@ -18,6 +18,7 @@ import proto # type: ignore +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import machine_resources from google.protobuf import timestamp_pb2 as timestamp # type: ignore @@ -42,14 +43,14 @@ class Endpoint(proto.Message): can be consist of any UTF-8 characters. description (str): The description of the Endpoint. - deployed_models (Sequence[~.endpoint.DeployedModel]): + deployed_models (Sequence[google.cloud.aiplatform_v1beta1.types.DeployedModel]): Output only. The models deployed in this Endpoint. To add or remove DeployedModels use ``EndpointService.DeployModel`` and ``EndpointService.UndeployModel`` respectively. - traffic_split (Sequence[~.endpoint.Endpoint.TrafficSplitEntry]): + traffic_split (Sequence[google.cloud.aiplatform_v1beta1.types.Endpoint.TrafficSplitEntry]): A map from a DeployedModel's ID to the percentage of this Endpoint's traffic that should be forwarded to that DeployedModel. @@ -63,7 +64,7 @@ class Endpoint(proto.Message): Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update happens. - labels (Sequence[~.endpoint.Endpoint.LabelsEntry]): + labels (Sequence[google.cloud.aiplatform_v1beta1.types.Endpoint.LabelsEntry]): The labels with user-defined metadata to organize your Endpoints. Label keys and values can be no longer than 64 @@ -73,12 +74,17 @@ class Endpoint(proto.Message): are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this Endpoint was created. - update_time (~.timestamp.Timestamp): + update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this Endpoint was last updated. + encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): + Customer-managed encryption key spec for an + Endpoint. 
If set, this Endpoint and all sub- + resources of this Endpoint will be secured by + this key. """ name = proto.Field(proto.STRING, number=1) @@ -101,17 +107,21 @@ class Endpoint(proto.Message): update_time = proto.Field(proto.MESSAGE, number=9, message=timestamp.Timestamp,) + encryption_spec = proto.Field( + proto.MESSAGE, number=10, message=gca_encryption_spec.EncryptionSpec, + ) + class DeployedModel(proto.Message): r"""A deployment of a Model. Endpoints contain one or more DeployedModels. Attributes: - dedicated_resources (~.machine_resources.DedicatedResources): + dedicated_resources (google.cloud.aiplatform_v1beta1.types.DedicatedResources): A description of resources that are dedicated to the DeployedModel, and that need a higher degree of manual configuration. - automatic_resources (~.machine_resources.AutomaticResources): + automatic_resources (google.cloud.aiplatform_v1beta1.types.AutomaticResources): A description of resources that to large degree are decided by AI Platform, and require only a modest additional configuration. @@ -125,10 +135,10 @@ class DeployedModel(proto.Message): display_name (str): The display name of the DeployedModel. If not provided upon creation, the Model's display_name is used. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when the DeployedModel was created. - explanation_spec (~.explanation.ExplanationSpec): + explanation_spec (google.cloud.aiplatform_v1beta1.types.ExplanationSpec): Explanation configuration for this DeployedModel. When deploying a Model using @@ -160,7 +170,7 @@ class DeployedModel(proto.Message): send ``stderr`` and ``stdout`` streams to Stackdriver Logging. - Only supported for custom-trained Models and AutoML Tables + Only supported for custom-trained Models and AutoML Tabular Models. enable_access_logging (bool): These logs are like standard server access diff --git a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py index 4bc9f35594..72d1063334 100644 --- a/google/cloud/aiplatform_v1beta1/types/endpoint_service.py +++ b/google/cloud/aiplatform_v1beta1/types/endpoint_service.py @@ -52,7 +52,7 @@ class CreateEndpointRequest(proto.Message): Required. The resource name of the Location to create the Endpoint in. Format: ``projects/{project}/locations/{location}`` - endpoint (~.gca_endpoint.Endpoint): + endpoint (google.cloud.aiplatform_v1beta1.types.Endpoint): Required. The Endpoint to create. """ @@ -66,7 +66,7 @@ class CreateEndpointOperationMetadata(proto.Message): ``EndpointService.CreateEndpoint``. Attributes: - generic_metadata (~.operation.GenericOperationMetadata): + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The operation generic information. """ @@ -103,23 +103,21 @@ class ListEndpointsRequest(proto.Message): supported. - ``endpoint`` supports = and !=. ``endpoint`` represents - the Endpoint ID, ie. the last segment of the Endpoint's + the Endpoint ID, i.e. the last segment of the Endpoint's [resource name][google.cloud.aiplatform.v1beta1.Endpoint.name]. - - ``display_name`` supports =, != and regex() (uses - `re2 `__ - syntax) + - ``display_name`` supports = and, != - ``labels`` supports general map functions that is: - ``labels.key=value`` - key:value equality - ``labels.key:* or labels:key - key existence A key including a space must be quoted.``\ labels."a - key"`. 
+ + - ``labels.key=value`` - key:value equality + - \`labels.key:\* or labels:key - key existence + - A key including a space must be quoted. + ``labels."a key"``. Some examples: - ``endpoint=1`` - ``displayName="myDisplayName"`` - - \`regex(display_name, "^A") -> The display name starts - with an A. - ``labels.myKey="myValue"`` page_size (int): Optional. The standard list page size. @@ -130,7 +128,7 @@ class ListEndpointsRequest(proto.Message): of the previous ``EndpointService.ListEndpoints`` call. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Optional. Mask specifying which fields to read. """ @@ -151,7 +149,7 @@ class ListEndpointsResponse(proto.Message): ``EndpointService.ListEndpoints``. Attributes: - endpoints (Sequence[~.gca_endpoint.Endpoint]): + endpoints (Sequence[google.cloud.aiplatform_v1beta1.types.Endpoint]): List of Endpoints in the requested page. next_page_token (str): A token to retrieve next page of results. Pass to @@ -175,12 +173,12 @@ class UpdateEndpointRequest(proto.Message): ``EndpointService.UpdateEndpoint``. Attributes: - endpoint (~.gca_endpoint.Endpoint): + endpoint (google.cloud.aiplatform_v1beta1.types.Endpoint): Required. The Endpoint which replaces the resource on the server. - update_mask (~.field_mask.FieldMask): - Required. The update mask applies to the - resource. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The update mask applies to the resource. See + `FieldMask `__. """ endpoint = proto.Field(proto.MESSAGE, number=1, message=gca_endpoint.Endpoint,) @@ -211,14 +209,14 @@ class DeployModelRequest(proto.Message): Required. The name of the Endpoint resource into which to deploy a Model. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` - deployed_model (~.gca_endpoint.DeployedModel): + deployed_model (google.cloud.aiplatform_v1beta1.types.DeployedModel): Required. The DeployedModel to be created within the Endpoint. Note that ``Endpoint.traffic_split`` must be updated for the DeployedModel to start receiving traffic, either as part of this call, or via ``EndpointService.UpdateEndpoint``. - traffic_split (Sequence[~.endpoint_service.DeployModelRequest.TrafficSplitEntry]): + traffic_split (Sequence[google.cloud.aiplatform_v1beta1.types.DeployModelRequest.TrafficSplitEntry]): A map from a DeployedModel's ID to the percentage of this Endpoint's traffic that should be forwarded to that DeployedModel. @@ -250,7 +248,7 @@ class DeployModelResponse(proto.Message): ``EndpointService.DeployModel``. Attributes: - deployed_model (~.gca_endpoint.DeployedModel): + deployed_model (google.cloud.aiplatform_v1beta1.types.DeployedModel): The DeployedModel that had been deployed in the Endpoint. """ @@ -265,7 +263,7 @@ class DeployModelOperationMetadata(proto.Message): ``EndpointService.DeployModel``. Attributes: - generic_metadata (~.operation.GenericOperationMetadata): + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The operation generic information. """ @@ -286,7 +284,7 @@ class UndeployModelRequest(proto.Message): deployed_model_id (str): Required. The ID of the DeployedModel to be undeployed from the Endpoint. - traffic_split (Sequence[~.endpoint_service.UndeployModelRequest.TrafficSplitEntry]): + traffic_split (Sequence[google.cloud.aiplatform_v1beta1.types.UndeployModelRequest.TrafficSplitEntry]): If this field is provided, then the Endpoint's ``traffic_split`` will be overwritten with it. 
If last DeployedModel is being @@ -315,7 +313,7 @@ class UndeployModelOperationMetadata(proto.Message): ``EndpointService.UndeployModel``. Attributes: - generic_metadata (~.operation.GenericOperationMetadata): + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The operation generic information. """ diff --git a/google/cloud/aiplatform_v1beta1/types/env_var.py b/google/cloud/aiplatform_v1beta1/types/env_var.py index 207e8275cd..0d2c3769ff 100644 --- a/google/cloud/aiplatform_v1beta1/types/env_var.py +++ b/google/cloud/aiplatform_v1beta1/types/env_var.py @@ -32,14 +32,14 @@ class EnvVar(proto.Message): Required. Name of the environment variable. Must be a valid C identifier. value (str): - Variables that reference a $(VAR_NAME) are expanded using - the previous defined environment variables in the container - and any service environment variables. If a variable cannot - be resolved, the reference in the input string will be - unchanged. The $(VAR_NAME) syntax can be escaped with a - double $$, ie: $$(VAR_NAME). Escaped references will never - be expanded, regardless of whether the variable exists or - not. + Required. Variables that reference a $(VAR_NAME) are + expanded using the previous defined environment variables in + the container and any service environment variables. If a + variable cannot be resolved, the reference in the input + string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped + references will never be expanded, regardless of whether the + variable exists or not. """ name = proto.Field(proto.STRING, number=1) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation.py b/google/cloud/aiplatform_v1beta1/types/explanation.py index 7a495fff1e..d9b48b60ab 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation.py @@ -35,6 +35,8 @@ "XraiAttribution", "SmoothGradConfig", "FeatureNoiseSigma", + "ExplanationSpecOverride", + "ExplanationMetadataOverride", }, ) @@ -46,7 +48,7 @@ class Explanation(proto.Message): ``instance``. Attributes: - attributions (Sequence[~.explanation.Attribution]): + attributions (Sequence[google.cloud.aiplatform_v1beta1.types.Attribution]): Output only. Feature attributions grouped by predicted outputs. @@ -79,8 +81,8 @@ class ModelExplanation(proto.Message): instances. Attributes: - mean_attributions (Sequence[~.explanation.Attribution]): - Output only. Aggregated attributions explaning the Model's + mean_attributions (Sequence[google.cloud.aiplatform_v1beta1.types.Attribution]): + Output only. Aggregated attributions explaining the Model's prediction outputs over the set of instances. The attributions are grouped by outputs. @@ -140,7 +142,7 @@ class Attribution(proto.Message): If the Model predicted output has multiple dimensions, this is the value in the output located by ``output_index``. - feature_attributions (~.struct.Value): + feature_attributions (google.protobuf.struct_pb2.Value): Output only. Attributions of each explained feature. Features are extracted from the [prediction instances][google.cloud.aiplatform.v1beta1.ExplainRequest.instances] @@ -190,9 +192,9 @@ class Attribution(proto.Message): of the output vector. Indices start from 0. output_display_name (str): Output only. The display name of the output identified by - ``output_index``, - e.g. the predicted class name by a multi-classification - Model. + ``output_index``. 
+ For example, the predicted class name by a + multi-classification Model. This field is only populated iff the Model predicts display names as a separate field along with the explained output. @@ -204,25 +206,25 @@ class Attribution(proto.Message): caused by approximation used in the explanation method. Lower value means more precise attributions. - - For [Sampled Shapley - attribution][ExplanationParameters.sampled_shapley_attribution], + - For Sampled Shapley + ``attribution``, increasing ``path_count`` - may reduce the error. - - For [Integrated Gradients - attribution][ExplanationParameters.integrated_gradients_attribution], + might reduce the error. + - For Integrated Gradients + ``attribution``, increasing ``step_count`` - may reduce the error. + might reduce the error. - For [XRAI - attribution][ExplanationParameters.xrai_attribution], + attribution][google.cloud.aiplatform.v1beta1.ExplanationParameters.xrai_attribution], increasing ``step_count`` - may reduce the error. - - Refer to AI Explanations Whitepaper for more details: + might reduce the error. - https://storage.googleapis.com/cloud-ai-whitepapers/AI%20Explainability%20Whitepaper.pdf + See `this + introduction `__ + for more information. output_name (str): Output only. Name of the explain output. Specified as the key in @@ -248,10 +250,10 @@ class ExplanationSpec(proto.Message): r"""Specification of Model explanation. Attributes: - parameters (~.explanation.ExplanationParameters): + parameters (google.cloud.aiplatform_v1beta1.types.ExplanationParameters): Required. Parameters that configure explaining of the Model's predictions. - metadata (~.explanation_metadata.ExplanationMetadata): + metadata (google.cloud.aiplatform_v1beta1.types.ExplanationMetadata): Required. Metadata describing the Model's input and output for explanation. """ @@ -267,7 +269,7 @@ class ExplanationParameters(proto.Message): r"""Parameters to configure explaining for Model's predictions. Attributes: - sampled_shapley_attribution (~.explanation.SampledShapleyAttribution): An attribution method that approximates Shapley values for features that contribute to the label being predicted. A sampling strategy @@ -275,13 +277,13 @@ class ExplanationParameters(proto.Message): considering all subsets of features. Refer to this paper for model details: https://arxiv.org/abs/1306.4265. - integrated_gradients_attribution (~.explanation.IntegratedGradientsAttribution): An attribution method that computes Aumann- Shapley values taking advantage of the model's fully differentiable structure. Refer to this paper for more details: https://arxiv.org/abs/1703.01365 - xrai_attribution (~.explanation.XraiAttribution): + xrai_attribution (google.cloud.aiplatform_v1beta1.types.XraiAttribution): An attribution method that redistributes Integrated Gradients attribution to segmented regions, taking advantage of the model's fully @@ -301,11 +303,11 @@ class ExplanationParameters(proto.Message): to Models that predicts more than one outputs (e,g, multi-class Models). When set to -1, returns explanations for all outputs. - output_indices (~.struct.ListValue): + output_indices (google.protobuf.struct_pb2.ListValue): If populated, only returns attributions that have - ``output_index`` contained in - output_indices.
It must be an ndarray of integers, with the - same shape of the output it's explaining. + ``output_index`` + contained in output_indices. It must be an ndarray of + integers, with the same shape of the output it's explaining. If not populated, returns attributions for ``top_k`` @@ -367,7 +369,7 @@ class IntegratedGradientsAttribution(proto.Message): range. Valid range of its value is [1, 100], inclusively. - smooth_grad_config (~.explanation.SmoothGradConfig): + smooth_grad_config (google.cloud.aiplatform_v1beta1.types.SmoothGradConfig): Config for SmoothGrad approximation of gradients. When enabled, the gradients are approximated by @@ -387,12 +389,11 @@ class IntegratedGradientsAttribution(proto.Message): class XraiAttribution(proto.Message): r"""An explanation method that redistributes Integrated Gradients - attributions to segmented regions, taking advantage of the model's - fully differentiable structure. Refer to this paper for more - details: https://arxiv.org/abs/1906.02825 + attributions to segmented regions, taking advantage of the + model's fully differentiable structure. Refer to this paper for + more details: https://arxiv.org/abs/1906.02825 - Only supports image Models (``modality`` is - IMAGE). + Supported only by image Models. Attributes: step_count (int): @@ -402,7 +403,7 @@ class XraiAttribution(proto.Message): error range. Valid range of its value is [1, 100], inclusively. - smooth_grad_config (~.explanation.SmoothGradConfig): + smooth_grad_config (google.cloud.aiplatform_v1beta1.types.SmoothGradConfig): Config for SmoothGrad approximation of gradients. When enabled, the gradients are approximated by @@ -434,10 +435,8 @@ class SmoothGradConfig(proto.Message): to all the features. Use this field when all features are normalized to have the same distribution: scale to range [0, 1], [-1, 1] or z-scoring, where features are normalized to - have 0-mean and 1-variance. Refer to this doc for more - details about normalization: - - https://developers.google.com/machine-learning/data-prep/transform/normalization. + have 0-mean and 1-variance. For more details about + normalization: https://tinyurl.com/dgc-normalization. For best results the recommended value is about 10% - 20% of the standard deviation of the input feature. Refer to @@ -447,7 +446,7 @@ class SmoothGradConfig(proto.Message): If the distribution is different per feature, set ``feature_noise_sigma`` instead for each feature. - feature_noise_sigma (~.explanation.FeatureNoiseSigma): + feature_noise_sigma (google.cloud.aiplatform_v1beta1.types.FeatureNoiseSigma): This is similar to ``noise_sigma``, but provides additional flexibility. A separate noise sigma @@ -481,7 +480,7 @@ class FeatureNoiseSigma(proto.Message): to interpolated inputs prior to computing gradients. Attributes: - noise_sigma (Sequence[~.explanation.FeatureNoiseSigma.NoiseSigmaForFeature]): + noise_sigma (Sequence[google.cloud.aiplatform_v1beta1.types.FeatureNoiseSigma.NoiseSigmaForFeature]): Noise sigma per feature. No noise is added to features that are not set. """ @@ -512,4 +511,72 @@ class NoiseSigmaForFeature(proto.Message): ) +class ExplanationSpecOverride(proto.Message): + r"""The + ``ExplanationSpec`` + entries that can be overridden at [online + explanation]``PredictionService.Explain`` + time. + + Attributes: + parameters (google.cloud.aiplatform_v1beta1.types.ExplanationParameters): + The parameters to be overridden. Note that the + ``method`` + cannot be changed. If not specified, no parameter is + overridden. 
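To make the override mechanics concrete, a hedged sketch of the parameters override just described (the metadata override field is documented next); the path count is invented, and per the docstring the attribution method itself cannot be changed at Explain time:

    from google.cloud.aiplatform_v1beta1 import types

    override = types.ExplanationSpecOverride(
        parameters=types.ExplanationParameters(
            # Same method as the deployed ExplanationSpec, different path_count.
            sampled_shapley_attribution=types.SampledShapleyAttribution(path_count=10),
        ),
    )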
+ metadata (google.cloud.aiplatform_v1beta1.types.ExplanationMetadataOverride): + The metadata to be overridden. If not + specified, no metadata is overridden. + """ + + parameters = proto.Field(proto.MESSAGE, number=1, message="ExplanationParameters",) + + metadata = proto.Field( + proto.MESSAGE, number=2, message="ExplanationMetadataOverride", + ) + + +class ExplanationMetadataOverride(proto.Message): + r"""The + ``ExplanationMetadata`` + entries that can be overridden at [online + explanation][google.cloud.aiplatform.v1beta1.PredictionService.Explain] + time. + + Attributes: + inputs (Sequence[google.cloud.aiplatform_v1beta1.types.ExplanationMetadataOverride.InputsEntry]): + Required. Overrides the [input + metadata][google.cloud.aiplatform.v1beta1.ExplanationMetadata.inputs] + of the features. The key is the name of the feature to be + overridden. The keys specified here must exist in the input + metadata to be overridden. If a feature is not specified + here, the corresponding feature's input metadata is not + overridden. + """ + + class InputMetadataOverride(proto.Message): + r"""The [input + metadata][google.cloud.aiplatform.v1beta1.ExplanationMetadata.InputMetadata] + entries to be overridden. + + Attributes: + input_baselines (Sequence[google.protobuf.struct_pb2.Value]): + Baseline inputs for this feature. + + This overrides the ``input_baseline`` field of the + ``ExplanationMetadata.InputMetadata`` + object of the corresponding feature's input metadata. If + it's not specified, the original baselines are not + overridden. + """ + + input_baselines = proto.RepeatedField( + proto.MESSAGE, number=1, message=struct.Value, + ) + + inputs = proto.MapField( + proto.STRING, proto.MESSAGE, number=1, message=InputMetadataOverride, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py index 7261c064f8..dcbc32a4f5 100644 --- a/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py +++ b/google/cloud/aiplatform_v1beta1/types/explanation_metadata.py @@ -31,7 +31,7 @@ class ExplanationMetadata(proto.Message): explanation. Attributes: - inputs (Sequence[~.explanation_metadata.ExplanationMetadata.InputsEntry]): + inputs (Sequence[google.cloud.aiplatform_v1beta1.types.ExplanationMetadata.InputsEntry]): Required. Map from feature names to feature input metadata. Keys are the name of the features. Values are the specification of the feature. @@ -42,13 +42,13 @@ class ExplanationMetadata(proto.Message): The baseline of the empty feature is chosen by AI Platform. For AI Platform provided Tensorflow images, the key can be - any friendly name of the feature . Once specified, [ - featureAttributions][Attribution.feature_attributions] will - be keyed by this key (if not grouped with another feature). + any friendly name of the feature. Once specified, + ``featureAttributions`` + are keyed by this key (if not grouped with another feature). For custom images, the key must match with the key in ``instance``. - outputs (Sequence[~.explanation_metadata.ExplanationMetadata.OutputsEntry]): + outputs (Sequence[google.cloud.aiplatform_v1beta1.types.ExplanationMetadata.OutputsEntry]): Required. Map from output names to output metadata. For AI Platform provided Tensorflow images, keys @@ -80,7 +80,7 @@ class InputMetadata(proto.Message): images for Tensorflow. 
Attributes: - input_baselines (Sequence[~.struct.Value]): + input_baselines (Sequence[google.protobuf.struct_pb2.Value]): Baseline inputs for this feature. If no baseline is specified, AI Platform chooses the @@ -105,13 +105,13 @@ class InputMetadata(proto.Message): Name of the input tensor for this feature. Required and is only applicable to AI Platform provided images for Tensorflow. - encoding (~.explanation_metadata.ExplanationMetadata.InputMetadata.Encoding): + encoding (google.cloud.aiplatform_v1beta1.types.ExplanationMetadata.InputMetadata.Encoding): Defines how the feature is encoded into the input tensor. Defaults to IDENTITY. modality (str): Modality of the feature. Valid values are: numeric, image. Defaults to numeric. - feature_value_domain (~.explanation_metadata.ExplanationMetadata.InputMetadata.FeatureValueDomain): + feature_value_domain (google.cloud.aiplatform_v1beta1.types.ExplanationMetadata.InputMetadata.FeatureValueDomain): The domain details of the input feature value. Like min/max, original mean or standard deviation if normalized. @@ -140,13 +140,13 @@ class InputMetadata(proto.Message): An encoded tensor is generated if the input tensor is encoded by a lookup table. - encoded_baselines (Sequence[~.struct.Value]): + encoded_baselines (Sequence[google.protobuf.struct_pb2.Value]): A list of baselines for the encoded tensor. The shape of each baseline should match the shape of the encoded tensor. If a scalar is provided, AI Platform broadcast to the same shape as the encoded tensor. - visualization (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization): + visualization (google.cloud.aiplatform_v1beta1.types.ExplanationMetadata.InputMetadata.Visualization): Visualization configurations for image explanation. group_name (str): @@ -211,17 +211,17 @@ class Visualization(proto.Message): r"""Visualization configurations for image explanation. Attributes: - type_ (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization.Type): + type_ (google.cloud.aiplatform_v1beta1.types.ExplanationMetadata.InputMetadata.Visualization.Type): Type of the image visualization. Only applicable to [Integrated Gradients attribution] [ExplanationParameters.integrated_gradients_attribution]. OUTLINES shows regions of attribution, while PIXELS shows per-pixel attribution. Defaults to OUTLINES. - polarity (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization.Polarity): + polarity (google.cloud.aiplatform_v1beta1.types.ExplanationMetadata.InputMetadata.Visualization.Polarity): Whether to only highlight pixels with positive contributions, negative or both. Defaults to POSITIVE. - color_map (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization.ColorMap): + color_map (google.cloud.aiplatform_v1beta1.types.ExplanationMetadata.InputMetadata.Visualization.ColorMap): The color scheme used for the highlighted areas. Defaults to PINK_GREEN for [Integrated Gradients @@ -243,7 +243,7 @@ class Visualization(proto.Message): Excludes attributions below the specified percentile, from the highlighted areas. Defaults to 35. - overlay_type (~.explanation_metadata.ExplanationMetadata.InputMetadata.Visualization.OverlayType): + overlay_type (google.cloud.aiplatform_v1beta1.types.ExplanationMetadata.InputMetadata.Visualization.OverlayType): How the original image is displayed in the visualization. 
Adjusting the overlay can help increase visual clarity if the original image @@ -357,7 +357,7 @@ class OutputMetadata(proto.Message): r"""Metadata of the prediction output to be explained. Attributes: - index_display_name_mapping (~.struct.Value): + index_display_name_mapping (google.protobuf.struct_pb2.Value): Static mapping between the index and display name. Use this if the outputs are a deterministic n-dimensional diff --git a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py index 78af635e79..55978a409e 100644 --- a/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py +++ b/google/cloud/aiplatform_v1beta1/types/hyperparameter_tuning_job.py @@ -19,6 +19,7 @@ from google.cloud.aiplatform_v1beta1.types import custom_job +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform_v1beta1.types import job_state from google.cloud.aiplatform_v1beta1.types import study from google.protobuf import timestamp_pb2 as timestamp # type: ignore @@ -44,7 +45,7 @@ class HyperparameterTuningJob(proto.Message): HyperparameterTuningJob. The name can be up to 128 characters long and can be consist of any UTF-8 characters. - study_spec (~.study.StudySpec): + study_spec (google.cloud.aiplatform_v1beta1.types.StudySpec): Required. Study configuration of the HyperparameterTuningJob. max_trial_count (int): @@ -57,33 +58,33 @@ class HyperparameterTuningJob(proto.Message): seen before failing the HyperparameterTuningJob. If set to 0, AI Platform decides how many Trials must fail before the whole job fails. - trial_job_spec (~.custom_job.CustomJobSpec): + trial_job_spec (google.cloud.aiplatform_v1beta1.types.CustomJobSpec): Required. The spec of a trial job. The same spec applies to the CustomJobs created in all the trials. - trials (Sequence[~.study.Trial]): + trials (Sequence[google.cloud.aiplatform_v1beta1.types.Trial]): Output only. Trials of the HyperparameterTuningJob. - state (~.job_state.JobState): + state (google.cloud.aiplatform_v1beta1.types.JobState): Output only. The detailed state of the job. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the HyperparameterTuningJob was created. - start_time (~.timestamp.Timestamp): + start_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the HyperparameterTuningJob for the first time entered the ``JOB_STATE_RUNNING`` state. - end_time (~.timestamp.Timestamp): + end_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the HyperparameterTuningJob entered any of the following states: ``JOB_STATE_SUCCEEDED``, ``JOB_STATE_FAILED``, ``JOB_STATE_CANCELLED``. - update_time (~.timestamp.Timestamp): + update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the HyperparameterTuningJob was most recently updated. - error (~.status.Status): + error (google.rpc.status_pb2.Status): Output only. Only populated when job's state is JOB_STATE_FAILED or JOB_STATE_CANCELLED. - labels (Sequence[~.hyperparameter_tuning_job.HyperparameterTuningJob.LabelsEntry]): + labels (Sequence[google.cloud.aiplatform_v1beta1.types.HyperparameterTuningJob.LabelsEntry]): The labels with user-defined metadata to organize HyperparameterTuningJobs. Label keys and values can be no longer than 64 @@ -93,6 +94,12 @@ class HyperparameterTuningJob(proto.Message): are allowed. 
See https://goo.gl/xmQnxf for more information and examples of labels. + encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): + Customer-managed encryption key options for a + HyperparameterTuningJob. If this is set, then + all resources created by the + HyperparameterTuningJob will be encrypted with + the provided encryption key. """ name = proto.Field(proto.STRING, number=1) @@ -127,5 +134,9 @@ class HyperparameterTuningJob(proto.Message): labels = proto.MapField(proto.STRING, proto.STRING, number=16) + encryption_spec = proto.Field( + proto.MESSAGE, number=17, message=gca_encryption_spec.EncryptionSpec, + ) + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/io.py b/google/cloud/aiplatform_v1beta1/types/io.py index f5fcc170f9..b032cc2bae 100644 --- a/google/cloud/aiplatform_v1beta1/types/io.py +++ b/google/cloud/aiplatform_v1beta1/types/io.py @@ -79,10 +79,17 @@ class BigQueryDestination(proto.Message): Attributes: output_uri (str): - Required. BigQuery URI to a project, up to 2000 characters - long. Accepted forms: + Required. BigQuery URI to a project or table, up to 2000 + characters long. + + When only project is specified, Dataset and Table is + created. When full table reference is specified, Dataset + must exist and table must not exist. - - BigQuery path. For example: ``bq://projectId``. + Accepted forms: + + - BigQuery path. For example: ``bq://projectId`` or + ``bq://projectId.bqDatasetId.bqTableId``. """ output_uri = proto.Field(proto.STRING, number=1) diff --git a/google/cloud/aiplatform_v1beta1/types/job_service.py b/google/cloud/aiplatform_v1beta1/types/job_service.py index f64f07cbe3..ead3e7c765 100644 --- a/google/cloud/aiplatform_v1beta1/types/job_service.py +++ b/google/cloud/aiplatform_v1beta1/types/job_service.py @@ -71,7 +71,7 @@ class CreateCustomJobRequest(proto.Message): Required. The resource name of the Location to create the CustomJob in. Format: ``projects/{project}/locations/{location}`` - custom_job (~.gca_custom_job.CustomJob): + custom_job (google.cloud.aiplatform_v1beta1.types.CustomJob): Required. The CustomJob to create. """ @@ -128,7 +128,7 @@ class ListCustomJobsRequest(proto.Message): of the previous ``JobService.ListCustomJobs`` call. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. """ @@ -148,7 +148,7 @@ class ListCustomJobsResponse(proto.Message): ``JobService.ListCustomJobs`` Attributes: - custom_jobs (Sequence[~.gca_custom_job.CustomJob]): + custom_jobs (Sequence[google.cloud.aiplatform_v1beta1.types.CustomJob]): List of CustomJobs in the requested page. next_page_token (str): A token to retrieve next page of results. Pass to @@ -202,7 +202,7 @@ class CreateDataLabelingJobRequest(proto.Message): parent (str): Required. The parent of the DataLabelingJob. Format: ``projects/{project}/locations/{location}`` - data_labeling_job (~.gca_data_labeling_job.DataLabelingJob): + data_labeling_job (google.cloud.aiplatform_v1beta1.types.DataLabelingJob): Required. The DataLabelingJob to create. """ @@ -255,7 +255,7 @@ class ListDataLabelingJobsRequest(proto.Message): The standard list page size. page_token (str): The standard list page token. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. FieldMask represents a set of symbolic field paths. For example, the mask can be ``paths: "name"``. 
The "name" here is a field in @@ -285,7 +285,7 @@ class ListDataLabelingJobsResponse(proto.Message): ``JobService.ListDataLabelingJobs``. Attributes: - data_labeling_jobs (Sequence[~.gca_data_labeling_job.DataLabelingJob]): + data_labeling_jobs (Sequence[google.cloud.aiplatform_v1beta1.types.DataLabelingJob]): A list of DataLabelingJobs that matches the specified filter in the request. next_page_token (str): @@ -341,7 +341,7 @@ class CreateHyperparameterTuningJobRequest(proto.Message): Required. The resource name of the Location to create the HyperparameterTuningJob in. Format: ``projects/{project}/locations/{location}`` - hyperparameter_tuning_job (~.gca_hyperparameter_tuning_job.HyperparameterTuningJob): + hyperparameter_tuning_job (google.cloud.aiplatform_v1beta1.types.HyperparameterTuningJob): Required. The HyperparameterTuningJob to create. """ @@ -405,7 +405,7 @@ class ListHyperparameterTuningJobsRequest(proto.Message): of the previous ``JobService.ListHyperparameterTuningJobs`` call. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. """ @@ -425,7 +425,7 @@ class ListHyperparameterTuningJobsResponse(proto.Message): ``JobService.ListHyperparameterTuningJobs`` Attributes: - hyperparameter_tuning_jobs (Sequence[~.gca_hyperparameter_tuning_job.HyperparameterTuningJob]): + hyperparameter_tuning_jobs (Sequence[google.cloud.aiplatform_v1beta1.types.HyperparameterTuningJob]): List of HyperparameterTuningJobs in the requested page. ``HyperparameterTuningJob.trials`` of the jobs will be not be returned. @@ -487,7 +487,7 @@ class CreateBatchPredictionJobRequest(proto.Message): Required. The resource name of the Location to create the BatchPredictionJob in. Format: ``projects/{project}/locations/{location}`` - batch_prediction_job (~.gca_batch_prediction_job.BatchPredictionJob): + batch_prediction_job (google.cloud.aiplatform_v1beta1.types.BatchPredictionJob): Required. The BatchPredictionJob to create. """ @@ -548,7 +548,7 @@ class ListBatchPredictionJobsRequest(proto.Message): of the previous ``JobService.ListBatchPredictionJobs`` call. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. """ @@ -568,7 +568,7 @@ class ListBatchPredictionJobsResponse(proto.Message): ``JobService.ListBatchPredictionJobs`` Attributes: - batch_prediction_jobs (Sequence[~.gca_batch_prediction_job.BatchPredictionJob]): + batch_prediction_jobs (Sequence[google.cloud.aiplatform_v1beta1.types.BatchPredictionJob]): List of BatchPredictionJobs in the requested page. next_page_token (str): diff --git a/google/cloud/aiplatform_v1beta1/types/machine_resources.py b/google/cloud/aiplatform_v1beta1/types/machine_resources.py index c71aca024e..50dc4b3eef 100644 --- a/google/cloud/aiplatform_v1beta1/types/machine_resources.py +++ b/google/cloud/aiplatform_v1beta1/types/machine_resources.py @@ -41,46 +41,21 @@ class MachineSpec(proto.Message): Attributes: machine_type (str): - Immutable. The type of the machine. 
Following machine types - are supported: - - ``n1-standard-2`` - - ``n1-standard-4`` - - ``n1-standard-8`` - - ``n1-standard-16`` - - ``n1-standard-32`` - - ``n1-highmem-2`` - - ``n1-highmem-4`` - - ``n1-highmem-8`` - - ``n1-highmem-16`` - - ``n1-highmem-32`` - - ``n1-highcpu-2`` - - ``n1-highcpu-4`` - - ``n1-highcpu-8`` - - ``n1-highcpu-16`` - - ``n1-highcpu-32`` - - When used for [DeployedMode][] this field is optional and - the default value is ``n1-standard-2``. If used for + Immutable. The type of the machine. For the machine types + supported for prediction, see + https://tinyurl.com/aip-docs/predictions/machine-types. For + machine types supported for creating a custom training job, + see https://tinyurl.com/aip-docs/training/configure-compute. + + For + ``DeployedModel`` + this field is optional, and the default value is + ``n1-standard-2``. For ``BatchPredictionJob`` or as part of ``WorkerPoolSpec`` this field is required. - accelerator_type (~.gca_accelerator_type.AcceleratorType): + accelerator_type (google.cloud.aiplatform_v1beta1.types.AcceleratorType): Immutable. The type of accelerator(s) that may be attached to the machine as per ``accelerator_count``. @@ -104,7 +79,7 @@ class DedicatedResources(proto.Message): configuration. Attributes: - machine_spec (~.machine_resources.MachineSpec): + machine_spec (google.cloud.aiplatform_v1beta1.types.MachineSpec): Required. Immutable. The specification of a single machine used by the prediction. min_replica_count (int): @@ -181,7 +156,7 @@ class BatchDedicatedResources(proto.Message): configuration. Attributes: - machine_spec (~.machine_resources.MachineSpec): + machine_spec (google.cloud.aiplatform_v1beta1.types.MachineSpec): Required. Immutable. The specification of a single machine. starting_replica_count (int): @@ -222,10 +197,10 @@ class DiskSpec(proto.Message): Attributes: boot_disk_type (str): - Type of the boot disk (default is "pd- - standard"). Valid values: "pd-ssd" (Persistent - Disk Solid State Drive) or "pd-standard" - (Persistent Disk Hard Disk Drive). + Type of the boot disk (default is "pd-ssd"). + Valid values: "pd-ssd" (Persistent Disk Solid + State Drive) or "pd-standard" (Persistent Disk + Hard Disk Drive). boot_disk_size_gb (int): Size in GB of the boot disk (default is 100GB). diff --git a/google/cloud/aiplatform_v1beta1/types/migratable_resource.py b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py index 99a6e65a42..144ff94acc 100644 --- a/google/cloud/aiplatform_v1beta1/types/migratable_resource.py +++ b/google/cloud/aiplatform_v1beta1/types/migratable_resource.py @@ -31,24 +31,24 @@ class MigratableResource(proto.Message): datalabeling.googleapis.com or ml.googleapis.com. Attributes: - ml_engine_model_version (~.migratable_resource.MigratableResource.MlEngineModelVersion): + ml_engine_model_version (google.cloud.aiplatform_v1beta1.types.MigratableResource.MlEngineModelVersion): Output only. Represents one Version in ml.googleapis.com. - automl_model (~.migratable_resource.MigratableResource.AutomlModel): + automl_model (google.cloud.aiplatform_v1beta1.types.MigratableResource.AutomlModel): Output only. Represents one Model in automl.googleapis.com. - automl_dataset (~.migratable_resource.MigratableResource.AutomlDataset): + automl_dataset (google.cloud.aiplatform_v1beta1.types.MigratableResource.AutomlDataset): Output only. Represents one Dataset in automl.googleapis.com.
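Before the remaining MigratableResource attributes, a hedged sketch of how these output-only entries are typically obtained through MigrationService; the project and location are placeholders, and credentials are assumed to come from the environment:

    from google.cloud import aiplatform_v1beta1

    client = aiplatform_v1beta1.MigrationServiceClient()
    for resource in client.search_migratable_resources(
        parent="projects/my-project/locations/us-central1"
    ):
        # Exactly one of ml_engine_model_version, automl_model, automl_dataset,
        # or data_labeling_dataset is populated per result.
        print(resource)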
- data_labeling_dataset (~.migratable_resource.MigratableResource.DataLabelingDataset): + data_labeling_dataset (google.cloud.aiplatform_v1beta1.types.MigratableResource.DataLabelingDataset): Output only. Represents one Dataset in datalabeling.googleapis.com. - last_migrate_time (~.timestamp.Timestamp): + last_migrate_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when last migrate attempt on this MigratableResource started. Will not be set if there's no migrate attempt on this MigratableResource. - last_update_time (~.timestamp.Timestamp): + last_update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this MigratableResource was last updated. """ @@ -116,7 +116,7 @@ class DataLabelingDataset(proto.Message): dataset_display_name (str): The Dataset's display name in datalabeling.googleapis.com. - data_labeling_annotated_datasets (Sequence[~.migratable_resource.MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset]): + data_labeling_annotated_datasets (Sequence[google.cloud.aiplatform_v1beta1.types.MigratableResource.DataLabelingDataset.DataLabelingAnnotatedDataset]): The migratable AnnotatedDataset in datalabeling.googleapis.com belongs to the data labeling Dataset. diff --git a/google/cloud/aiplatform_v1beta1/types/migration_service.py b/google/cloud/aiplatform_v1beta1/types/migration_service.py index 46b0cdc66b..ef37006233 100644 --- a/google/cloud/aiplatform_v1beta1/types/migration_service.py +++ b/google/cloud/aiplatform_v1beta1/types/migration_service.py @@ -22,6 +22,7 @@ migratable_resource as gca_migratable_resource, ) from google.cloud.aiplatform_v1beta1.types import operation +from google.rpc import status_pb2 as status # type: ignore __protobuf__ = proto.module( @@ -54,6 +55,23 @@ class SearchMigratableResourcesRequest(proto.Message): The default and maximum value is 100. page_token (str): The standard page token. + filter (str): + Supported filters are: + + - Resource type: For a specific type of MigratableResource. + + - ``ml_engine_model_version:*`` + - ``automl_model:*``, + - ``automl_dataset:*`` + - ``data_labeling_dataset:*``. + + - Migrated or not: Filter migrated resource or not by + last_migrate_time. + + - ``last_migrate_time:*`` will filter migrated + resources. + - ``NOT last_migrate_time:*`` will filter not yet + migrated resource. """ parent = proto.Field(proto.STRING, number=1) @@ -62,13 +80,15 @@ class SearchMigratableResourcesRequest(proto.Message): page_token = proto.Field(proto.STRING, number=3) + filter = proto.Field(proto.STRING, number=4) + class SearchMigratableResourcesResponse(proto.Message): r"""Response message for ``MigrationService.SearchMigratableResources``. Attributes: - migratable_resources (Sequence[~.gca_migratable_resource.MigratableResource]): + migratable_resources (Sequence[google.cloud.aiplatform_v1beta1.types.MigratableResource]): All migratable resources that can be migrated to the location specified in the request. next_page_token (str): @@ -96,7 +116,7 @@ class BatchMigrateResourcesRequest(proto.Message): parent (str): Required. The location of the migrated resource will live in. Format: ``projects/{project}/locations/{location}`` - migrate_resource_requests (Sequence[~.migration_service.MigrateResourceRequest]): + migrate_resource_requests (Sequence[google.cloud.aiplatform_v1beta1.types.MigrateResourceRequest]): Required. The request messages specifying the resources to migrate. They must be in the same location as the destination. 
Up to 50 resources @@ -116,16 +136,16 @@ class MigrateResourceRequest(proto.Message): Platform. Attributes: - migrate_ml_engine_model_version_config (~.migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig): + migrate_ml_engine_model_version_config (google.cloud.aiplatform_v1beta1.types.MigrateResourceRequest.MigrateMlEngineModelVersionConfig): Config for migrating Version in ml.googleapis.com to AI Platform's Model. - migrate_automl_model_config (~.migration_service.MigrateResourceRequest.MigrateAutomlModelConfig): + migrate_automl_model_config (google.cloud.aiplatform_v1beta1.types.MigrateResourceRequest.MigrateAutomlModelConfig): Config for migrating Model in automl.googleapis.com to AI Platform's Model. - migrate_automl_dataset_config (~.migration_service.MigrateResourceRequest.MigrateAutomlDatasetConfig): + migrate_automl_dataset_config (google.cloud.aiplatform_v1beta1.types.MigrateResourceRequest.MigrateAutomlDatasetConfig): Config for migrating Dataset in automl.googleapis.com to AI Platform's Dataset. - migrate_data_labeling_dataset_config (~.migration_service.MigrateResourceRequest.MigrateDataLabelingDatasetConfig): + migrate_data_labeling_dataset_config (google.cloud.aiplatform_v1beta1.types.MigrateResourceRequest.MigrateDataLabelingDatasetConfig): Config for migrating Dataset in datalabeling.googleapis.com to AI Platform's Dataset. @@ -211,7 +231,7 @@ class MigrateDataLabelingDatasetConfig(proto.Message): Optional. Display name of the Dataset in AI Platform. System will pick a display name if unspecified. - migrate_data_labeling_annotated_dataset_configs (Sequence[~.migration_service.MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig]): + migrate_data_labeling_annotated_dataset_configs (Sequence[google.cloud.aiplatform_v1beta1.types.MigrateResourceRequest.MigrateDataLabelingDatasetConfig.MigrateDataLabelingAnnotatedDatasetConfig]): Optional. Configs for migrating AnnotatedDataset in datalabeling.googleapis.com to AI Platform's SavedQuery. The specified @@ -271,7 +291,7 @@ class BatchMigrateResourcesResponse(proto.Message): ``MigrationService.BatchMigrateResources``. Attributes: - migrate_resource_responses (Sequence[~.migration_service.MigrateResourceResponse]): + migrate_resource_responses (Sequence[google.cloud.aiplatform_v1beta1.types.MigrateResourceResponse]): Successfully migrated resources. """ @@ -288,7 +308,7 @@ class MigrateResourceResponse(proto.Message): Migrated Dataset's resource name. model (str): Migrated Model's resource name. - migratable_resource (~.gca_migratable_resource.MigratableResource): + migratable_resource (google.cloud.aiplatform_v1beta1.types.MigratableResource): Before migration, the identifier in ml.googleapis.com, automl.googleapis.com or datalabeling.googleapis.com. @@ -308,13 +328,49 @@ class BatchMigrateResourcesOperationMetadata(proto.Message): ``MigrationService.BatchMigrateResources``. Attributes: - generic_metadata (~.operation.GenericOperationMetadata): + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The common part of the operation metadata. + partial_results (Sequence[google.cloud.aiplatform_v1beta1.types.BatchMigrateResourcesOperationMetadata.PartialResult]): + Partial results that reflect the latest + migration operation progress. """ + class PartialResult(proto.Message): + r"""Represents a partial result in batch migration operation for one + ``MigrateResourceRequest``.
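A short sketch of the new ``filter`` field on SearchMigratableResourcesRequest, using a filter expression taken directly from the field documentation above; the parent resource name is a placeholder:

from google.cloud import aiplatform_v1beta1

client = aiplatform_v1beta1.MigrationServiceClient(
    client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
)
# Return only resources with no migration attempt yet.
pager = client.search_migratable_resources(
    request={
        "parent": "projects/my-project/locations/us-central1",  # placeholder
        "filter": "NOT last_migrate_time:*",
    }
)
for migratable_resource in pager:
    print(migratable_resource)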
+ + Attributes: + error (google.rpc.status_pb2.Status): + The error result of the migration request in + case of failure. + model (str): + Migrated model resource name. + dataset (str): + Migrated dataset resource name. + request (google.cloud.aiplatform_v1beta1.types.MigrateResourceRequest): + It's the same as the value in + [MigrateResourceRequest.migrate_resource_requests][]. + """ + + error = proto.Field( + proto.MESSAGE, number=2, oneof="result", message=status.Status, + ) + + model = proto.Field(proto.STRING, number=3, oneof="result") + + dataset = proto.Field(proto.STRING, number=4, oneof="result") + + request = proto.Field( + proto.MESSAGE, number=1, message="MigrateResourceRequest", + ) + generic_metadata = proto.Field( proto.MESSAGE, number=1, message=operation.GenericOperationMetadata, ) + partial_results = proto.RepeatedField( + proto.MESSAGE, number=2, message=PartialResult, + ) + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/model.py b/google/cloud/aiplatform_v1beta1/types/model.py index 21e8c41034..c9e1ddd68a 100644 --- a/google/cloud/aiplatform_v1beta1/types/model.py +++ b/google/cloud/aiplatform_v1beta1/types/model.py @@ -19,6 +19,7 @@ from google.cloud.aiplatform_v1beta1.types import deployed_model_ref +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform_v1beta1.types import env_var from google.cloud.aiplatform_v1beta1.types import explanation from google.protobuf import struct_pb2 as struct # type: ignore @@ -43,7 +44,7 @@ class Model(proto.Message): can consist of any UTF-8 characters. description (str): The description of the Model. - predict_schemata (~.model.PredictSchemata): + predict_schemata (google.cloud.aiplatform_v1beta1.types.PredictSchemata): The schemata that describe formats of the Model's predictions and explanations as given and returned via ``PredictionService.Predict`` @@ -62,12 +63,12 @@ class Model(proto.Message): be immutable and probably different, including the URI scheme, than the one given on input. The output URI will point to a location where the user has only read access. - metadata (~.struct.Value): + metadata (google.protobuf.struct_pb2.Value): Immutable. Additional information about the Model; the schema of the metadata can be found in ``metadata_schema``. Unset if the Model does not have any additional information. - supported_export_formats (Sequence[~.model.Model.ExportFormat]): + supported_export_formats (Sequence[google.cloud.aiplatform_v1beta1.types.Model.ExportFormat]): Output only. The formats in which this Model may be exported. If empty, this Model is not available for export. @@ -75,7 +76,7 @@ class Model(proto.Message): Output only. The resource name of the TrainingPipeline that uploaded this Model, if any. - container_spec (~.model.ModelContainerSpec): + container_spec (google.cloud.aiplatform_v1beta1.types.ModelContainerSpec): Input only. The specification of the container that is to be used when deploying this Model. The specification is ingested upon @@ -86,7 +87,7 @@ class Model(proto.Message): Immutable. The path to the directory containing the Model artifact and any of its supporting files. Not present for AutoML Models. - supported_deployment_resources_types (Sequence[~.model.Model.DeploymentResourcesType]): + supported_deployment_resources_types (Sequence[google.cloud.aiplatform_v1beta1.types.Model.DeploymentResourcesType]): Output only.
When this Model is deployed, its prediction resources are described by the ``prediction_resources`` field of the @@ -136,6 +137,11 @@ class Model(proto.Message): Uses ``BigQuerySource``. + - ``file-list`` Each line of the file is the location of an + instance to process, uses ``gcs_source`` field of the + ``InputConfig`` + object. + If this Model doesn't support any of these formats it means it cannot be used with a ``BatchPredictionJob``. @@ -182,41 +188,40 @@ class Model(proto.Message): ``PredictionService.Predict`` or ``PredictionService.Explain``. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this Model was uploaded into AI Platform. - update_time (~.timestamp.Timestamp): + update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this Model was most recently updated. - deployed_models (Sequence[~.deployed_model_ref.DeployedModelRef]): + deployed_models (Sequence[google.cloud.aiplatform_v1beta1.types.DeployedModelRef]): Output only. The pointers to DeployedModels created from this Model. Note that Model could have been deployed to Endpoints in different Locations. - explanation_spec (~.explanation.ExplanationSpec): - Output only. The default explanation specification for this - Model. + explanation_spec (google.cloud.aiplatform_v1beta1.types.ExplanationSpec): + The default explanation specification for this Model. - Model can be used for [requesting - explanation][google.cloud.aiplatform.v1beta1.PredictionService.Explain] - after being + The Model can be used for [requesting + explanation][PredictionService.Explain] after being ``deployed`` - iff it is populated. + iff it is populated. The Model can be used for [batch + explanation][BatchPredictionJob.generate_explanation] iff it + is populated. All fields of the explanation_spec can be overridden by ``explanation_spec`` of - ``DeployModelRequest.deployed_model``. - - This field is populated only for tabular AutoML Models. - Specifying it with - ``ModelService.UploadModel`` - is not supported. + ``DeployModelRequest.deployed_model``, + or + ``explanation_spec`` + of + ``BatchPredictionJob``. etag (str): Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update happens. - labels (Sequence[~.model.Model.LabelsEntry]): + labels (Sequence[google.cloud.aiplatform_v1beta1.types.Model.LabelsEntry]): The labels with user-defined metadata to organize your Models. Label keys and values can be no longer than 64 @@ -226,6 +231,10 @@ class Model(proto.Message): are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. + encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): + Customer-managed encryption key spec for a + Model. If set, this Model and all sub-resources + of this Model will be secured by this key. """ class DeploymentResourcesType(proto.Enum): @@ -260,7 +269,7 @@ class ExportFormat(proto.Message): - ``custom-trained`` A Model that was uploaded or trained by custom code. - exportable_contents (Sequence[~.model.Model.ExportFormat.ExportableContent]): + exportable_contents (Sequence[google.cloud.aiplatform_v1beta1.types.Model.ExportFormat.ExportableContent]): Output only. The content of this Model that may be exported. 
""" @@ -323,6 +332,10 @@ class ExportableContent(proto.Enum): labels = proto.MapField(proto.STRING, proto.STRING, number=17) + encryption_spec = proto.Field( + proto.MESSAGE, number=24, message=gca_encryption_spec.EncryptionSpec, + ) + class PredictSchemata(proto.Message): r"""Contains the schemata used in Model's predictions and explanations @@ -491,7 +504,7 @@ class ModelContainerSpec(proto.Message): field corresponds to the ``args`` field of the Kubernetes Containers `v1 core API `__. - env (Sequence[~.env_var.EnvVar]): + env (Sequence[google.cloud.aiplatform_v1beta1.types.EnvVar]): Immutable. List of environment variables to set in the container. After the container starts running, code running in the container can read these environment variables. @@ -524,7 +537,7 @@ class ModelContainerSpec(proto.Message): This field corresponds to the ``env`` field of the Kubernetes Containers `v1 core API `__. - ports (Sequence[~.model.Port]): + ports (Sequence[google.cloud.aiplatform_v1beta1.types.Port]): Immutable. List of ports to expose from the container. AI Platform sends any prediction requests that it receives to the first port on this list. AI Platform also sends @@ -557,9 +570,9 @@ class ModelContainerSpec(proto.Message): For example, if you set this field to ``/foo``, then when AI Platform receives a prediction request, it forwards the - request body in a POST request to the following URL on the - container: localhost:PORT/foo PORT refers to the first value - of this ``ModelContainerSpec``'s + request body in a POST request to the ``/foo`` path on the + port of your container specified by the first value of this + ``ModelContainerSpec``'s ``ports`` field. @@ -590,9 +603,9 @@ class ModelContainerSpec(proto.Message): checks `__. For example, if you set this field to ``/bar``, then AI - Platform intermittently sends a GET request to the following - URL on the container: localhost:PORT/bar PORT refers to the - first value of this ``ModelContainerSpec``'s + Platform intermittently sends a GET request to the ``/bar`` + path on the port of your container specified by the first + value of this ``ModelContainerSpec``'s ``ports`` field. diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py index b768ed978e..391bc38cf4 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation.py @@ -44,11 +44,11 @@ class ModelEvaluation(proto.Message): of this ModelEvaluation. The schema is defined as an OpenAPI 3.0.2 `Schema Object `__. - metrics (~.struct.Value): + metrics (google.protobuf.struct_pb2.Value): Output only. Evaluation metrics of the Model. The schema of the metrics is stored in ``metrics_schema_uri`` - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this ModelEvaluation was created. slice_dimensions (Sequence[str]): @@ -58,15 +58,41 @@ class ModelEvaluation(proto.Message): filter of the ``ModelService.ListModelEvaluationSlices`` request, in the form of ``slice.dimension = ``. - model_explanation (~.explanation.ModelExplanation): + model_explanation (google.cloud.aiplatform_v1beta1.types.ModelExplanation): Output only. Aggregated explanation metrics for the Model's prediction output over the data this ModelEvaluation uses. This field is populated only if the Model is evaluated with explanations, and only for AutoML tabular Models. 
+ explanation_specs (Sequence[google.cloud.aiplatform_v1beta1.types.ModelEvaluation.ModelEvaluationExplanationSpec]): + Output only. Describes the values of + ``ExplanationSpec`` + that are used for explaining the predicted values on the + evaluated data. """ + class ModelEvaluationExplanationSpec(proto.Message): + r""" + + Attributes: + explanation_type (str): + Explanation type. + + For AutoML Image Classification models, possible values are: + + - ``image-integrated-gradients`` + - ``image-xrai`` + explanation_spec (google.cloud.aiplatform_v1beta1.types.ExplanationSpec): + Explanation spec details. + """ + + explanation_type = proto.Field(proto.STRING, number=1) + + explanation_spec = proto.Field( + proto.MESSAGE, number=2, message=explanation.ExplanationSpec, + ) + name = proto.Field(proto.STRING, number=1) metrics_schema_uri = proto.Field(proto.STRING, number=2) @@ -81,5 +107,9 @@ class ModelEvaluation(proto.Message): proto.MESSAGE, number=8, message=explanation.ModelExplanation, ) + explanation_specs = proto.RepeatedField( + proto.MESSAGE, number=9, message=ModelEvaluationExplanationSpec, + ) + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py index 1039d32c1f..2d66e29a9f 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py +++ b/google/cloud/aiplatform_v1beta1/types/model_evaluation_slice.py @@ -36,7 +36,7 @@ class ModelEvaluationSlice(proto.Message): name (str): Output only. The resource name of the ModelEvaluationSlice. - slice_ (~.model_evaluation_slice.ModelEvaluationSlice.Slice): + slice_ (google.cloud.aiplatform_v1beta1.types.ModelEvaluationSlice.Slice): Output only. The slice of the test data that is used to evaluate the Model. metrics_schema_uri (str): @@ -46,11 +46,11 @@ class ModelEvaluationSlice(proto.Message): of this ModelEvaluationSlice. The schema is defined as an OpenAPI 3.0.2 `Schema Object `__. - metrics (~.struct.Value): + metrics (google.protobuf.struct_pb2.Value): Output only. Sliced evaluation metrics of the Model. The schema of the metrics is stored in ``metrics_schema_uri`` - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Timestamp when this ModelEvaluationSlice was created. """ diff --git a/google/cloud/aiplatform_v1beta1/types/model_service.py b/google/cloud/aiplatform_v1beta1/types/model_service.py index 3cfb17ad2c..1091c71148 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_service.py +++ b/google/cloud/aiplatform_v1beta1/types/model_service.py @@ -59,7 +59,7 @@ class UploadModelRequest(proto.Message): Required. The resource name of the Location into which to upload the Model. Format: ``projects/{project}/locations/{location}`` - model (~.gca_model.Model): + model (google.cloud.aiplatform_v1beta1.types.Model): Required. The Model to create. """ @@ -74,7 +74,7 @@ class UploadModelOperationMetadata(proto.Message): operation. Attributes: - generic_metadata (~.operation.GenericOperationMetadata): + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The common part of the operation metadata. """ @@ -120,7 +120,25 @@ class ListModelsRequest(proto.Message): Models from. Format: ``projects/{project}/locations/{location}`` filter (str): - The standard list filter. + An expression for filtering the results of the request. For + field names both snake_case and camelCase are supported. 
+ + - ``model`` supports = and !=. ``model`` represents the + Model ID, i.e. the last segment of the Model's [resource + name][google.cloud.aiplatform.v1beta1.Model.name]. + - ``display_name`` supports = and != + - ``labels`` supports general map functions, that is: + + - ``labels.key=value`` - key:value equality + - ``labels.key:*`` or ``labels:key`` - key existence + - A key including a space must be quoted. + ``labels."a key"``. + + Some examples: + + - ``model=1234`` + - ``displayName="myDisplayName"`` + - ``labels.myKey="myValue"`` page_size (int): The standard list page size. page_token (str): The standard page token. Typically obtained via @@ -129,7 +147,7 @@ class ListModelsRequest(proto.Message): of the previous ``ModelService.ListModels`` call. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. """ @@ -149,7 +167,7 @@ class ListModelsResponse(proto.Message): ``ModelService.ListModels`` Attributes: - models (Sequence[~.gca_model.Model]): + models (Sequence[google.cloud.aiplatform_v1beta1.types.Model]): List of Models in the requested page. next_page_token (str): A token to retrieve next page of results. Pass to @@ -171,14 +189,13 @@ class UpdateModelRequest(proto.Message): ``ModelService.UpdateModel``. Attributes: - model (~.gca_model.Model): + model (google.cloud.aiplatform_v1beta1.types.Model): Required. The Model which replaces the resource on the server. - update_mask (~.field_mask.FieldMask): + update_mask (google.protobuf.field_mask_pb2.FieldMask): Required. The update mask applies to the resource. For the ``FieldMask`` definition, see - - [FieldMask](https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask). + `FieldMask <https://developers.google.com/protocol-buffers/docs/reference/google.protobuf#fieldmask>`__. """ model = proto.Field(proto.MESSAGE, number=1, message=gca_model.Model,) @@ -208,7 +225,7 @@ class ExportModelRequest(proto.Message): name (str): Required. The resource name of the Model to export. Format: ``projects/{project}/locations/{location}/models/{model}`` - output_config (~.model_service.ExportModelRequest.OutputConfig): + output_config (google.cloud.aiplatform_v1beta1.types.ExportModelRequest.OutputConfig): Required. The desired output location and configuration. """ @@ -223,23 +240,23 @@ class OutputConfig(proto.Message): supports][google.cloud.aiplatform.v1beta1.Model.supported_export_formats]. If no value is provided here, then the first from the list of the Model's supported formats is used by default. - artifact_destination (~.io.GcsDestination): - The Google Cloud Storage location where the Model artifact - is to be written to. Under the directory given as the - destination a new one with name + artifact_destination (google.cloud.aiplatform_v1beta1.types.GcsDestination): + The Cloud Storage location where the Model artifact is to be + written to. Under the directory given as the destination a + new one with name "``model-export-<model-display-name>-<timestamp-of-export-call>``", where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format, will be created. Inside, the Model and any of its supporting files will be written. This field should only be - set when - [Models.supported_export_formats.exportable_contents] - contains ARTIFACT. - image_destination (~.io.ContainerRegistryDestination): + set when the ``exportableContent`` field of the + [Model.supported_export_formats] object contains + ``ARTIFACT``. + image_destination (google.cloud.aiplatform_v1beta1.types.ContainerRegistryDestination): The Google Container Registry or Artifact Registry uri where the Model container image will be copied to.
This field - should only be set when - [Models.supported_export_formats.exportable_contents] - contains IMAGE. + should only be set when the ``exportableContent`` field of + the [Model.supported_export_formats] object contains + ``IMAGE``. """ export_format_id = proto.Field(proto.STRING, number=1) @@ -263,9 +280,9 @@ class ExportModelOperationMetadata(proto.Message): operation. Attributes: - generic_metadata (~.operation.GenericOperationMetadata): + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The common part of the operation metadata. - output_info (~.model_service.ExportModelOperationMetadata.OutputInfo): + output_info (google.cloud.aiplatform_v1beta1.types.ExportModelOperationMetadata.OutputInfo): Output only. Information further describing the output of this Model export. """ @@ -338,7 +355,7 @@ class ListModelEvaluationsRequest(proto.Message): of the previous ``ModelService.ListModelEvaluations`` call. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. """ @@ -358,7 +375,7 @@ class ListModelEvaluationsResponse(proto.Message): ``ModelService.ListModelEvaluations``. Attributes: - model_evaluations (Sequence[~.model_evaluation.ModelEvaluation]): + model_evaluations (Sequence[google.cloud.aiplatform_v1beta1.types.ModelEvaluation]): List of ModelEvaluations in the requested page. next_page_token (str): @@ -415,7 +432,7 @@ class ListModelEvaluationSlicesRequest(proto.Message): of the previous ``ModelService.ListModelEvaluationSlices`` call. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. """ @@ -435,7 +452,7 @@ class ListModelEvaluationSlicesResponse(proto.Message): ``ModelService.ListModelEvaluationSlices``. Attributes: - model_evaluation_slices (Sequence[~.model_evaluation_slice.ModelEvaluationSlice]): + model_evaluation_slices (Sequence[google.cloud.aiplatform_v1beta1.types.ModelEvaluationSlice]): List of ModelEvaluations in the requested page. next_page_token (str): diff --git a/google/cloud/aiplatform_v1beta1/types/operation.py b/google/cloud/aiplatform_v1beta1/types/operation.py index 68fb0daead..90565867e8 100644 --- a/google/cloud/aiplatform_v1beta1/types/operation.py +++ b/google/cloud/aiplatform_v1beta1/types/operation.py @@ -32,16 +32,16 @@ class GenericOperationMetadata(proto.Message): r"""Generic Metadata shared by all operations. Attributes: - partial_failures (Sequence[~.status.Status]): + partial_failures (Sequence[google.rpc.status_pb2.Status]): Output only. Partial failures encountered. E.g. single files that couldn't be read. This field should never exceed 20 entries. Status details field will contain standard GCP error details. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the operation was created. - update_time (~.timestamp.Timestamp): + update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the operation was updated for the last time. If the operation has finished (successfully or not), this is the @@ -61,7 +61,7 @@ class DeleteOperationMetadata(proto.Message): r"""Details of operations that perform deletes of any entities. Attributes: - generic_metadata (~.operation.GenericOperationMetadata): + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The common part of the operation metadata. 
""" diff --git a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py index 9f0856732d..32eec8489c 100644 --- a/google/cloud/aiplatform_v1beta1/types/pipeline_service.py +++ b/google/cloud/aiplatform_v1beta1/types/pipeline_service.py @@ -46,7 +46,7 @@ class CreateTrainingPipelineRequest(proto.Message): Required. The resource name of the Location to create the TrainingPipeline in. Format: ``projects/{project}/locations/{location}`` - training_pipeline (~.gca_training_pipeline.TrainingPipeline): + training_pipeline (google.cloud.aiplatform_v1beta1.types.TrainingPipeline): Required. The TrainingPipeline to create. """ @@ -104,7 +104,7 @@ class ListTrainingPipelinesRequest(proto.Message): of the previous ``PipelineService.ListTrainingPipelines`` call. - read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. """ @@ -124,7 +124,7 @@ class ListTrainingPipelinesResponse(proto.Message): ``PipelineService.ListTrainingPipelines`` Attributes: - training_pipelines (Sequence[~.gca_training_pipeline.TrainingPipeline]): + training_pipelines (Sequence[google.cloud.aiplatform_v1beta1.types.TrainingPipeline]): List of TrainingPipelines in the requested page. next_page_token (str): diff --git a/google/cloud/aiplatform_v1beta1/types/prediction_service.py b/google/cloud/aiplatform_v1beta1/types/prediction_service.py index b000f88bf8..f7abe9e3e2 100644 --- a/google/cloud/aiplatform_v1beta1/types/prediction_service.py +++ b/google/cloud/aiplatform_v1beta1/types/prediction_service.py @@ -42,7 +42,7 @@ class PredictRequest(proto.Message): Required. The name of the Endpoint requested to serve the prediction. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` - instances (Sequence[~.struct.Value]): + instances (Sequence[google.protobuf.struct_pb2.Value]): Required. The instances that are the input to the prediction call. A DeployedModel may have an upper limit on the number of instances it supports per request, and when it is @@ -54,7 +54,7 @@ class PredictRequest(proto.Message): [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``instance_schema_uri``. - parameters (~.struct.Value): + parameters (google.protobuf.struct_pb2.Value): The parameters that govern the prediction. The schema of the parameters may be specified via Endpoint's DeployedModels' [Model's @@ -75,7 +75,7 @@ class PredictResponse(proto.Message): ``PredictionService.Predict``. Attributes: - predictions (Sequence[~.struct.Value]): + predictions (Sequence[google.protobuf.struct_pb2.Value]): The predictions that are the output of the predictions call. The schema of any single prediction may be specified via Endpoint's DeployedModels' [Model's @@ -101,7 +101,7 @@ class ExplainRequest(proto.Message): Required. The name of the Endpoint requested to serve the explanation. Format: ``projects/{project}/locations/{location}/endpoints/{endpoint}`` - instances (Sequence[~.struct.Value]): + instances (Sequence[google.protobuf.struct_pb2.Value]): Required. The instances that are the input to the explanation call. 
A DeployedModel may have an upper limit on the number of instances it supports per request, and when it @@ -113,13 +113,24 @@ class ExplainRequest(proto.Message): [Model's][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``instance_schema_uri``. - parameters (~.struct.Value): + parameters (google.protobuf.struct_pb2.Value): The parameters that govern the prediction. The schema of the parameters may be specified via Endpoint's DeployedModels' [Model's ][google.cloud.aiplatform.v1beta1.DeployedModel.model] [PredictSchemata's][google.cloud.aiplatform.v1beta1.Model.predict_schemata] ``parameters_schema_uri``. + explanation_spec_override (google.cloud.aiplatform_v1beta1.types.ExplanationSpecOverride): + If specified, overrides the + ``explanation_spec`` + of the DeployedModel. Can be used for explaining prediction + results with different configurations, such as: + + - Explaining top-5 predictions results as opposed to top-1; + - Increasing path count or step count of the attribution + methods to reduce approximate errors; + - Using different baselines for explaining the prediction + results. deployed_model_id (str): If specified, this ExplainRequest will be served by the chosen DeployedModel, overriding @@ -132,6 +143,10 @@ class ExplainRequest(proto.Message): parameters = proto.Field(proto.MESSAGE, number=4, message=struct.Value,) + explanation_spec_override = proto.Field( + proto.MESSAGE, number=5, message=explanation.ExplanationSpecOverride, + ) + deployed_model_id = proto.Field(proto.STRING, number=3) @@ -140,7 +155,7 @@ class ExplainResponse(proto.Message): ``PredictionService.Explain``. Attributes: - explanations (Sequence[~.explanation.Explanation]): + explanations (Sequence[google.cloud.aiplatform_v1beta1.types.Explanation]): The explanations of the Model's ``PredictResponse.predictions``. @@ -150,7 +165,7 @@ class ExplainResponse(proto.Message): deployed_model_id (str): ID of the Endpoint's DeployedModel that served this explanation. - predictions (Sequence[~.struct.Value]): + predictions (Sequence[google.protobuf.struct_pb2.Value]): The predictions that are the output of the predictions call. Same as ``PredictResponse.predictions``. diff --git a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py index 724f7165a6..9d0241080b 100644 --- a/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py +++ b/google/cloud/aiplatform_v1beta1/types/specialist_pool_service.py @@ -47,7 +47,7 @@ class CreateSpecialistPoolRequest(proto.Message): Required. The parent Project name for the new SpecialistPool. The form is ``projects/{project}/locations/{location}``. - specialist_pool (~.gca_specialist_pool.SpecialistPool): + specialist_pool (google.cloud.aiplatform_v1beta1.types.SpecialistPool): Required. The SpecialistPool to create. """ @@ -63,7 +63,7 @@ class CreateSpecialistPoolOperationMetadata(proto.Message): ``SpecialistPoolService.CreateSpecialistPool``. Attributes: - generic_metadata (~.operation.GenericOperationMetadata): + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The operation generic information. """ @@ -103,7 +103,7 @@ class ListSpecialistPoolsRequest(proto.Message): of the previous ``SpecialistPoolService.ListSpecialistPools`` call. Return first page if empty. 
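Picking up the ``explanation_spec_override`` field documented a few hunks above, a sketch of overriding the attribution parameters for a single Explain call; the endpoint and instance payload are placeholders:

from google.cloud import aiplatform_v1beta1
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

client = aiplatform_v1beta1.PredictionServiceClient(
    client_options={"api_endpoint": "us-central1-aiplatform.googleapis.com"}
)
instance = json_format.ParseDict({"feature": 0.5}, Value())  # placeholder payload
request = aiplatform_v1beta1.ExplainRequest(
    endpoint="projects/my-project/locations/us-central1/endpoints/123",  # placeholder
    instances=[instance],
    # Raise the XRAI step count for this request only, instead of the
    # value baked into the DeployedModel's explanation_spec.
    explanation_spec_override={
        "parameters": {"xrai_attribution": {"step_count": 50}}
    },
)
response = client.explain(request=request)
for explanation in response.explanations:
    print(explanation)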
- read_mask (~.field_mask.FieldMask): + read_mask (google.protobuf.field_mask_pb2.FieldMask): Mask specifying which fields to read. FieldMask represents a set of """ @@ -122,7 +122,7 @@ class ListSpecialistPoolsResponse(proto.Message): ``SpecialistPoolService.ListSpecialistPools``. Attributes: - specialist_pools (Sequence[~.gca_specialist_pool.SpecialistPool]): + specialist_pools (Sequence[google.cloud.aiplatform_v1beta1.types.SpecialistPool]): A list of SpecialistPools that matches the specified filter in the request. next_page_token (str): @@ -166,10 +166,10 @@ class UpdateSpecialistPoolRequest(proto.Message): ``SpecialistPoolService.UpdateSpecialistPool``. Attributes: - specialist_pool (~.gca_specialist_pool.SpecialistPool): + specialist_pool (google.cloud.aiplatform_v1beta1.types.SpecialistPool): Required. The SpecialistPool which replaces the resource on the server. - update_mask (~.field_mask.FieldMask): + update_mask (google.protobuf.field_mask_pb2.FieldMask): Required. The update mask applies to the resource. """ @@ -191,7 +191,7 @@ class UpdateSpecialistPoolOperationMetadata(proto.Message): specialists are being added. Format: ``projects/{project_id}/locations/{location_id}/specialistPools/{specialist_pool}`` - generic_metadata (~.operation.GenericOperationMetadata): + generic_metadata (google.cloud.aiplatform_v1beta1.types.GenericOperationMetadata): The operation generic information. """ diff --git a/google/cloud/aiplatform_v1beta1/types/study.py b/google/cloud/aiplatform_v1beta1/types/study.py index 2d6f4ae8c3..4f8b972746 100644 --- a/google/cloud/aiplatform_v1beta1/types/study.py +++ b/google/cloud/aiplatform_v1beta1/types/study.py @@ -37,16 +37,16 @@ class Trial(proto.Message): id (str): Output only. The identifier of the Trial assigned by the service. - state (~.study.Trial.State): + state (google.cloud.aiplatform_v1beta1.types.Trial.State): Output only. The detailed state of the Trial. - parameters (Sequence[~.study.Trial.Parameter]): + parameters (Sequence[google.cloud.aiplatform_v1beta1.types.Trial.Parameter]): Output only. The parameters of the Trial. - final_measurement (~.study.Measurement): + final_measurement (google.cloud.aiplatform_v1beta1.types.Measurement): Output only. The final measurement containing the objective value. - start_time (~.timestamp.Timestamp): + start_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the Trial was started. - end_time (~.timestamp.Timestamp): + end_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the Trial's status changed to ``SUCCEEDED`` or ``INFEASIBLE``. custom_job (str): @@ -72,7 +72,7 @@ class Parameter(proto.Message): Output only. The ID of the parameter. The parameter should be defined in [StudySpec's Parameters][google.cloud.aiplatform.v1beta1.StudySpec.parameters]. - value (~.struct.Value): + value (google.protobuf.struct_pb2.Value): Output only. The value of the parameter. ``number_value`` will be set if a parameter defined in StudySpec is in type 'INTEGER', 'DOUBLE' or 'DISCRETE'. ``string_value`` will be @@ -103,12 +103,20 @@ class StudySpec(proto.Message): r"""Represents specification of a Study. Attributes: - metrics (Sequence[~.study.StudySpec.MetricSpec]): + metrics (Sequence[google.cloud.aiplatform_v1beta1.types.StudySpec.MetricSpec]): Required. Metric specs for the Study. - parameters (Sequence[~.study.StudySpec.ParameterSpec]): + parameters (Sequence[google.cloud.aiplatform_v1beta1.types.StudySpec.ParameterSpec]): Required. The set of parameters to tune. 
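StudySpec is easier to read constructed than described, so here is a minimal sketch using the types from this module; the metric and parameter names are made up:

from google.cloud.aiplatform_v1beta1.types import study

spec = study.StudySpec(
    metrics=[
        study.StudySpec.MetricSpec(
            metric_id="accuracy",
            goal=study.StudySpec.MetricSpec.GoalType.MAXIMIZE,
        )
    ],
    parameters=[
        study.StudySpec.ParameterSpec(
            parameter_id="learning_rate",
            double_value_spec=study.StudySpec.ParameterSpec.DoubleValueSpec(
                min_value=1e-4, max_value=1e-1
            ),
            scale_type=study.StudySpec.ParameterSpec.ScaleType.UNIT_LOG_SCALE,
        )
    ],
    algorithm=study.StudySpec.Algorithm.RANDOM_SEARCH,
)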
- algorithm (~.study.StudySpec.Algorithm): + algorithm (google.cloud.aiplatform_v1beta1.types.StudySpec.Algorithm): The search algorithm specified for the Study. + observation_noise (google.cloud.aiplatform_v1beta1.types.StudySpec.ObservationNoise): + The observation noise level of the study. + Currently only supported by the Vizier service. + Not supported by HyperparameterTuningJob or + TrainingPipeline. + measurement_selection_type (google.cloud.aiplatform_v1beta1.types.StudySpec.MeasurementSelectionType): + Describes which measurement selection type + will be used. """ class Algorithm(proto.Enum): @@ -117,6 +125,33 @@ GRID_SEARCH = 2 RANDOM_SEARCH = 3 + class ObservationNoise(proto.Enum): + r"""Describes the noise level of the repeated observations. + "Noisy" means that the repeated observations with the same Trial + parameters may lead to different metric evaluations. + """ + OBSERVATION_NOISE_UNSPECIFIED = 0 + LOW = 1 + HIGH = 2 + + class MeasurementSelectionType(proto.Enum): + r"""This indicates which measurement to use if/when the service + automatically selects the final measurement from previously reported + intermediate measurements. Choose this based on two considerations: + A) Do you expect your measurements to monotonically improve? If so, + choose LAST_MEASUREMENT. On the other hand, if you're in a situation + where your system can "over-train" and you expect the performance to + get better for a while but then start declining, choose + BEST_MEASUREMENT. B) Are your measurements significantly noisy + and/or irreproducible? If so, BEST_MEASUREMENT will tend to be + over-optimistic, and it may be better to choose LAST_MEASUREMENT. If + both or neither of (A) and (B) apply, it doesn't matter which + selection type is chosen. + """ + MEASUREMENT_SELECTION_TYPE_UNSPECIFIED = 0 + LAST_MEASUREMENT = 1 + BEST_MEASUREMENT = 2 + class MetricSpec(proto.Message): r"""Represents a metric to optimize. Attributes: metric_id (str): Required. The ID of the metric. Must not contain whitespaces and must be unique amongst all MetricSpecs. - goal (~.study.StudySpec.MetricSpec.GoalType): + goal (google.cloud.aiplatform_v1beta1.types.StudySpec.MetricSpec.GoalType): Required. The optimization goal of the metric. """ @@ -144,22 +179,22 @@ class ParameterSpec(proto.Message): r"""Represents a single parameter to optimize. Attributes: - double_value_spec (~.study.StudySpec.ParameterSpec.DoubleValueSpec): + double_value_spec (google.cloud.aiplatform_v1beta1.types.StudySpec.ParameterSpec.DoubleValueSpec): The value spec for a 'DOUBLE' parameter. - integer_value_spec (~.study.StudySpec.ParameterSpec.IntegerValueSpec): + integer_value_spec (google.cloud.aiplatform_v1beta1.types.StudySpec.ParameterSpec.IntegerValueSpec): The value spec for an 'INTEGER' parameter. - categorical_value_spec (~.study.StudySpec.ParameterSpec.CategoricalValueSpec): + categorical_value_spec (google.cloud.aiplatform_v1beta1.types.StudySpec.ParameterSpec.CategoricalValueSpec): The value spec for a 'CATEGORICAL' parameter. - discrete_value_spec (~.study.StudySpec.ParameterSpec.DiscreteValueSpec): + discrete_value_spec (google.cloud.aiplatform_v1beta1.types.StudySpec.ParameterSpec.DiscreteValueSpec): The value spec for a 'DISCRETE' parameter. parameter_id (str): Required. The ID of the parameter. Must not contain whitespaces and must be unique amongst all ParameterSpecs.
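Continuing the sketch above, the two fields added in this change can be set on the same spec, following the guidance in the enum docstrings:

# Metrics from a typical training curve are assumed to improve
# monotonically, so LAST_MEASUREMENT is the appropriate choice; noisy
# measurements would also favor LAST_MEASUREMENT, since BEST_MEASUREMENT
# tends to be over-optimistic.
spec.observation_noise = study.StudySpec.ObservationNoise.LOW
spec.measurement_selection_type = (
    study.StudySpec.MeasurementSelectionType.LAST_MEASUREMENT
)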
- scale_type (~.study.StudySpec.ParameterSpec.ScaleType): + scale_type (google.cloud.aiplatform_v1beta1.types.StudySpec.ParameterSpec.ScaleType): How the parameter should be scaled. Leave unset for ``CATEGORICAL`` parameters. - conditional_parameter_specs (Sequence[~.study.StudySpec.ParameterSpec.ConditionalParameterSpec]): + conditional_parameter_specs (Sequence[google.cloud.aiplatform_v1beta1.types.StudySpec.ParameterSpec.ConditionalParameterSpec]): A conditional parameter node is active if the parameter's value matches the conditional node's parent_value_condition. @@ -236,16 +271,16 @@ class ConditionalParameterSpec(proto.Message): parameter. Attributes: - parent_discrete_values (~.study.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition): + parent_discrete_values (google.cloud.aiplatform_v1beta1.types.StudySpec.ParameterSpec.ConditionalParameterSpec.DiscreteValueCondition): The spec for matching values from a parent parameter of ``DISCRETE`` type. - parent_int_values (~.study.StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition): + parent_int_values (google.cloud.aiplatform_v1beta1.types.StudySpec.ParameterSpec.ConditionalParameterSpec.IntValueCondition): The spec for matching values from a parent parameter of ``INTEGER`` type. - parent_categorical_values (~.study.StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition): + parent_categorical_values (google.cloud.aiplatform_v1beta1.types.StudySpec.ParameterSpec.ConditionalParameterSpec.CategoricalValueCondition): The spec for matching values from a parent parameter of ``CATEGORICAL`` type. - parameter_spec (~.study.StudySpec.ParameterSpec): + parameter_spec (google.cloud.aiplatform_v1beta1.types.StudySpec.ParameterSpec): Required. The spec for a conditional parameter. """ @@ -362,6 +397,12 @@ class CategoricalValueCondition(proto.Message): algorithm = proto.Field(proto.ENUM, number=3, enum=Algorithm,) + observation_noise = proto.Field(proto.ENUM, number=6, enum=ObservationNoise,) + + measurement_selection_type = proto.Field( + proto.ENUM, number=7, enum=MeasurementSelectionType, + ) + class Measurement(proto.Message): r"""A message representing a Measurement of a Trial. A @@ -373,7 +414,7 @@ class Measurement(proto.Message): Output only. The number of steps the machine learning model has been trained for. Must be non-negative. - metrics (Sequence[~.study.Measurement.Metric]): + metrics (Sequence[google.cloud.aiplatform_v1beta1.types.Measurement.Metric]): Output only. A list of metrics got by evaluating the objective functions using suggested Parameter values. diff --git a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py index f1f0debaf9..6175a12e96 100644 --- a/google/cloud/aiplatform_v1beta1/types/training_pipeline.py +++ b/google/cloud/aiplatform_v1beta1/types/training_pipeline.py @@ -18,6 +18,7 @@ import proto # type: ignore +from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import model from google.cloud.aiplatform_v1beta1.types import pipeline_state @@ -54,7 +55,7 @@ class TrainingPipeline(proto.Message): display_name (str): Required. The user-defined name of this TrainingPipeline. 
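One last addition to that sketch, wiring up the ConditionalParameterSpec fields documented above: a hypothetical momentum parameter that is active only when a categorical optimizer parameter takes the value "sgd":

momentum = study.StudySpec.ParameterSpec(
    parameter_id="momentum",
    double_value_spec=study.StudySpec.ParameterSpec.DoubleValueSpec(
        min_value=0.0, max_value=0.99
    ),
)
optimizer = study.StudySpec.ParameterSpec(
    parameter_id="optimizer",
    categorical_value_spec=study.StudySpec.ParameterSpec.CategoricalValueSpec(
        values=["sgd", "adam"]
    ),
    conditional_parameter_specs=[
        study.StudySpec.ParameterSpec.ConditionalParameterSpec(
            # The child parameter is active only for the parent value "sgd".
            parent_categorical_values=(
                study.StudySpec.ParameterSpec.ConditionalParameterSpec
                .CategoricalValueCondition(values=["sgd"])
            ),
            parameter_spec=momentum,
        )
    ],
)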
- input_data_config (~.training_pipeline.InputDataConfig): + input_data_config (google.cloud.aiplatform_v1beta1.types.InputDataConfig): Specifies AI Platform owned input data that may be used for training the Model. The TrainingPipeline's ``training_task_definition`` @@ -77,12 +78,12 @@ class TrainingPipeline(proto.Message): than the one given on input. The output URI will point to a location where the user only has a read access. - training_task_inputs (~.struct.Value): + training_task_inputs (google.protobuf.struct_pb2.Value): Required. The training task's parameter(s), as specified in the ``training_task_definition``'s ``inputs``. - training_task_metadata (~.struct.Value): + training_task_metadata (google.protobuf.struct_pb2.Value): Output only. The metadata information as specified in the ``training_task_definition``'s ``metadata``. This metadata is an auxiliary runtime and @@ -91,10 +92,10 @@ class TrainingPipeline(proto.Message): best effort basis. Only present if the pipeline's ``training_task_definition`` contains ``metadata`` object. - model_to_upload (~.model.Model): + model_to_upload (google.cloud.aiplatform_v1beta1.types.Model): Describes the Model that may be uploaded (via - [ModelService.UploadMode][]) by this TrainingPipeline. The - TrainingPipeline's + ``ModelService.UploadModel``) + by this TrainingPipeline. The TrainingPipeline's ``training_task_definition`` should make clear whether this Model description should be populated, and if there are any special requirements @@ -111,26 +112,26 @@ class TrainingPipeline(proto.Message): resource ``name`` is populated. The Model is always uploaded into the Project and Location in which this pipeline is. - state (~.pipeline_state.PipelineState): + state (google.cloud.aiplatform_v1beta1.types.PipelineState): Output only. The detailed state of the pipeline. - error (~.status.Status): + error (google.rpc.status_pb2.Status): Output only. Only populated when the pipeline's state is ``PIPELINE_STATE_FAILED`` or ``PIPELINE_STATE_CANCELLED``. - create_time (~.timestamp.Timestamp): + create_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the TrainingPipeline was created. - start_time (~.timestamp.Timestamp): + start_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the TrainingPipeline for the first time entered the ``PIPELINE_STATE_RUNNING`` state. - end_time (~.timestamp.Timestamp): + end_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the TrainingPipeline entered any of the following states: ``PIPELINE_STATE_SUCCEEDED``, ``PIPELINE_STATE_FAILED``, ``PIPELINE_STATE_CANCELLED``. - update_time (~.timestamp.Timestamp): + update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. Time when the TrainingPipeline was most recently updated. - labels (Sequence[~.training_pipeline.TrainingPipeline.LabelsEntry]): + labels (Sequence[google.cloud.aiplatform_v1beta1.types.TrainingPipeline.LabelsEntry]): The labels with user-defined metadata to organize TrainingPipelines. Label keys and values can be no longer than 64 @@ -140,6 +141,14 @@ class TrainingPipeline(proto.Message): are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. + encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): + Customer-managed encryption key spec for a TrainingPipeline. + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if + ``model_to_upload`` + is not set separately. 
""" name = proto.Field(proto.STRING, number=1) @@ -170,38 +179,42 @@ class TrainingPipeline(proto.Message): labels = proto.MapField(proto.STRING, proto.STRING, number=15) + encryption_spec = proto.Field( + proto.MESSAGE, number=18, message=gca_encryption_spec.EncryptionSpec, + ) + class InputDataConfig(proto.Message): r"""Specifies AI Platform owned input data to be used for training, and possibly evaluating, the Model. Attributes: - fraction_split (~.training_pipeline.FractionSplit): + fraction_split (google.cloud.aiplatform_v1beta1.types.FractionSplit): Split based on fractions defining the size of each set. - filter_split (~.training_pipeline.FilterSplit): + filter_split (google.cloud.aiplatform_v1beta1.types.FilterSplit): Split based on the provided filters for each set. - predefined_split (~.training_pipeline.PredefinedSplit): + predefined_split (google.cloud.aiplatform_v1beta1.types.PredefinedSplit): Supported only for tabular Datasets. Split based on a predefined key. - timestamp_split (~.training_pipeline.TimestampSplit): + timestamp_split (google.cloud.aiplatform_v1beta1.types.TimestampSplit): Supported only for tabular Datasets. Split based on the timestamp of the input data pieces. - gcs_destination (~.io.GcsDestination): - The Google Cloud Storage location where the training data is - to be written to. In the given directory a new directory - will be created with name: + gcs_destination (google.cloud.aiplatform_v1beta1.types.GcsDestination): + The Cloud Storage location where the training data is to be + written to. In the given directory a new directory is + created with name: ``dataset---`` where timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 - format. All training input data will be written into that + format. All training input data is written into that directory. - The AI Platform environment variables representing Google - Cloud Storage data URIs will always be represented in the - Google Cloud Storage wildcard format to support sharded - data. e.g.: "gs://.../training-*.jsonl" + The AI Platform environment variables representing Cloud + Storage data URIs are represented in the Cloud Storage + wildcard format to support sharded data. e.g.: + "gs://.../training-*.jsonl" - AIP_DATA_FORMAT = "jsonl" for non-tabular data, "csv" for tabular data @@ -216,14 +229,17 @@ class InputDataConfig(proto.Message): - AIP_TEST_DATA_URI = "gcs_destination/dataset---/test-*.${AIP_DATA_FORMAT}". - bigquery_destination (~.io.BigQueryDestination): + bigquery_destination (google.cloud.aiplatform_v1beta1.types.BigQueryDestination): + Only applicable to custom training with tabular Dataset with + BigQuery source. + The BigQuery project location where the training data is to be written to. In the given project a new dataset is created with name ``dataset___`` where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All - training input data will be written into that dataset. In - the dataset three tables will be created, ``training``, + training input data is written into that dataset. In the + dataset three tables are created, ``training``, ``validation`` and ``test``. - AIP_DATA_FORMAT = "bigquery". @@ -247,7 +263,7 @@ class InputDataConfig(proto.Message): For tabular Datasets, all their data is exported to training, to pick and choose from. annotations_filter (str): - Only applicable to Datasets that have DataItems and + Applicable only to Datasets that have DataItems and Annotations. A filter on Annotations of the Dataset. 
Only Annotations @@ -261,13 +277,15 @@ class InputDataConfig(proto.Message): may be used, but note here it filters across all Annotations of the Dataset, and not just within a single DataItem. annotation_schema_uri (str): - Only applicable to custom training. + Applicable only to custom training with Datasets that have + DataItems and Annotations. - Google Cloud Storage URI points to a YAML file describing + Cloud Storage URI that points to a YAML file describing the annotation schema. The schema is defined as an OpenAPI 3.0.2 - [Schema Object](https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#schema-object) The schema files - that can be used here are found in - gs://google-cloud-aiplatform/schema/dataset/annotation/, + `Schema + Object `__. The + schema files that can be used here are found in + gs://google-cloud-aiplatform/schema/dataset/annotation/ , note that the chosen schema must be consistent with ``metadata`` of the Dataset specified by @@ -323,8 +341,8 @@ class FractionSplit(proto.Message): ``validation_fraction`` and ``test_fraction`` may optionally be provided, they must sum to up to 1. If the provided ones sum to less than 1, the remainder is assigned to sets as decided by AI Platform. - If none of the fractions are set, by default roughly 80% of data - will be used for training, 10% for validation, and 10% for test. + If none of the fractions are set, by default roughly 80% of data is + used for training, 10% for validation, and 10% for test. Attributes: training_fraction (float): @@ -353,6 +371,8 @@ class FilterSplit(proto.Message): If any of the filters in this message are to match nothing, then they can be set as '-' (the minus sign). + Supported only for unstructured Datasets. + Attributes: training_filter (str): Required. A filter on DataItems of the Dataset. DataItems @@ -360,27 +380,27 @@ class FilterSplit(proto.Message): with same syntax as the one used in ``DatasetService.ListDataItems`` may be used. If a single DataItem is matched by more than - one of the FilterSplit filters, then it will be assigned to - the first set that applies to it in the training, - validation, test order. + one of the FilterSplit filters, then it is assigned to the + first set that applies to it in the training, validation, + test order. validation_filter (str): Required. A filter on DataItems of the Dataset. DataItems that match this filter are used to validate the Model. A filter with same syntax as the one used in ``DatasetService.ListDataItems`` may be used. If a single DataItem is matched by more than - one of the FilterSplit filters, then it will be assigned to - the first set that applies to it in the training, - validation, test order. + one of the FilterSplit filters, then it is assigned to the + first set that applies to it in the training, validation, + test order. test_filter (str): Required. A filter on DataItems of the Dataset. DataItems that match this filter are used to test the Model. A filter with same syntax as the one used in ``DatasetService.ListDataItems`` may be used. If a single DataItem is matched by more than - one of the FilterSplit filters, then it will be assigned to - the first set that applies to it in the training, - validation, test order. + one of the FilterSplit filters, then it is assigned to the + first set that applies to it in the training, validation, + test order. 
""" training_filter = proto.Field(proto.STRING, number=1) diff --git a/noxfile.py b/noxfile.py index 87765339b5..6fd80b539a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -30,6 +30,17 @@ SYSTEM_TEST_PYTHON_VERSIONS = ["3.8"] UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"] +# 'docfx' is excluded since it only needs to run in 'docs-presubmit' +nox.options.sessions = [ + "unit", + "system", + "cover", + "lint", + "lint_setup_py", + "blacken", + "docs", +] + @nox.session(python=DEFAULT_PYTHON_VERSION) def lint(session): @@ -75,6 +86,7 @@ def default(session): session.install( "mock", "pytest", "pytest-cov", ) + session.install("-e", ".") # Run py.test against the unit tests. diff --git a/samples/snippets/create_batch_prediction_job_bigquery_sample.py b/samples/snippets/create_batch_prediction_job_bigquery_sample.py index e050976cb3..18429a3828 100644 --- a/samples/snippets/create_batch_prediction_job_bigquery_sample.py +++ b/samples/snippets/create_batch_prediction_job_bigquery_sample.py @@ -13,7 +13,7 @@ # limitations under the License. # [START aiplatform_create_batch_prediction_job_bigquery_sample] -from google.cloud import aiplatform +from google.cloud import aiplatform_v1beta1 from google.protobuf import json_format from google.protobuf.struct_pb2 import Value @@ -33,7 +33,7 @@ def create_batch_prediction_job_bigquery_sample( client_options = {"api_endpoint": api_endpoint} # Initialize client that will be used to create and send requests. # This client only needs to be created once, and can be reused for multiple requests. - client = aiplatform.gapic.JobServiceClient(client_options=client_options) + client = aiplatform_v1beta1.JobServiceClient(client_options=client_options) model_parameters_dict = {} model_parameters = json_format.ParseDict(model_parameters_dict, Value()) diff --git a/samples/snippets/create_batch_prediction_job_tabular_forecasting_sample.py b/samples/snippets/create_batch_prediction_job_tabular_forecasting_sample.py index 62eee08856..4f7319ea62 100644 --- a/samples/snippets/create_batch_prediction_job_tabular_forecasting_sample.py +++ b/samples/snippets/create_batch_prediction_job_tabular_forecasting_sample.py @@ -13,7 +13,7 @@ # limitations under the License. # [START aiplatform_create_batch_prediction_job_tabular_forecasting_sample] -from google.cloud import aiplatform +from google.cloud import aiplatform_v1beta1 def create_batch_prediction_job_tabular_forecasting_sample( @@ -30,7 +30,7 @@ def create_batch_prediction_job_tabular_forecasting_sample( client_options = {"api_endpoint": api_endpoint} # Initialize client that will be used to create and send requests. # This client only needs to be created once, and can be reused for multiple requests. 
- client = aiplatform.gapic.JobServiceClient(client_options=client_options) + client = aiplatform_v1beta1.JobServiceClient(client_options=client_options) batch_prediction_job = { "display_name": display_name, # Format: 'projects/{project}/locations/{location}/models/{model_id}' diff --git a/samples/snippets/explain_tabular_sample.py b/samples/snippets/explain_tabular_sample.py index cde1ec941e..38f9736cbc 100644 --- a/samples/snippets/explain_tabular_sample.py +++ b/samples/snippets/explain_tabular_sample.py @@ -15,7 +15,7 @@ # [START aiplatform_explain_tabular_sample] from typing import Dict -from google.cloud import aiplatform +from google.cloud import aiplatform_v1beta1 from google.protobuf import json_format from google.protobuf.struct_pb2 import Value @@ -31,7 +31,7 @@ def explain_tabular_sample( client_options = {"api_endpoint": api_endpoint} # Initialize client that will be used to create and send requests. # This client only needs to be created once, and can be reused for multiple requests. - client = aiplatform.gapic.PredictionServiceClient(client_options=client_options) + client = aiplatform_v1beta1.PredictionServiceClient(client_options=client_options) # The format of each instance should conform to the deployed model's prediction input schema. instance = json_format.ParseDict(instance_dict, Value()) instances = [instance] diff --git a/samples/snippets/export_model_tabular_classification_sample.py b/samples/snippets/export_model_tabular_classification_sample.py index 596adb8a64..f5c0101b16 100644 --- a/samples/snippets/export_model_tabular_classification_sample.py +++ b/samples/snippets/export_model_tabular_classification_sample.py @@ -13,7 +13,7 @@ # limitations under the License. # [START aiplatform_export_model_tabular_classification_sample] -from google.cloud import aiplatform +from google.cloud import aiplatform_v1beta1 def export_model_tabular_classification_sample( @@ -28,7 +28,7 @@ def export_model_tabular_classification_sample( client_options = {"api_endpoint": api_endpoint} # Initialize client that will be used to create and send requests. # This client only needs to be created once, and can be reused for multiple requests. - client = aiplatform.gapic.ModelServiceClient(client_options=client_options) + client = aiplatform_v1beta1.ModelServiceClient(client_options=client_options) gcs_destination = {"output_uri_prefix": gcs_destination_output_uri_prefix} output_config = { "artifact_destination": gcs_destination, diff --git a/samples/snippets/upload_model_explain_image_managed_container_sample.py b/samples/snippets/upload_model_explain_image_managed_container_sample.py index 8fddfc5a31..ab6b0dcf2d 100644 --- a/samples/snippets/upload_model_explain_image_managed_container_sample.py +++ b/samples/snippets/upload_model_explain_image_managed_container_sample.py @@ -13,7 +13,7 @@ # limitations under the License. # [START aiplatform_upload_model_explain_image_managed_container_sample] -from google.cloud import aiplatform +from google.cloud import aiplatform_v1beta1 def upload_model_explain_image_managed_container_sample( @@ -31,19 +31,19 @@ def upload_model_explain_image_managed_container_sample( client_options = {"api_endpoint": api_endpoint} # Initialize client that will be used to create and send requests. # This client only needs to be created once, and can be reused for multiple requests. 
- client = aiplatform.gapic.ModelServiceClient(client_options=client_options) + client = aiplatform_v1beta1.ModelServiceClient(client_options=client_options) # Container specification for deploying the model container_spec = {"image_uri": container_spec_image_uri, "command": [], "args": []} # The explainability method and corresponding parameters - parameters = aiplatform.gapic.ExplanationParameters( + parameters = aiplatform_v1beta1.ExplanationParameters( {"xrai_attribution": {"step_count": 1}} ) # The input tensor for feature attribution to the output # For single input model, y = f(x), this will be the serving input layer. - input_metadata = aiplatform.gapic.ExplanationMetadata.InputMetadata( + input_metadata = aiplatform_v1beta1.ExplanationMetadata.InputMetadata( { "input_tensor_name": input_tensor_name, # Input is image data @@ -53,21 +53,21 @@ def upload_model_explain_image_managed_container_sample( # The output tensor to explain # For single output model, y = f(x), this will be the serving output layer. - output_metadata = aiplatform.gapic.ExplanationMetadata.OutputMetadata( + output_metadata = aiplatform_v1beta1.ExplanationMetadata.OutputMetadata( {"output_tensor_name": output_tensor_name} ) # Assemble the explanation metadata - metadata = aiplatform.gapic.ExplanationMetadata( + metadata = aiplatform_v1beta1.ExplanationMetadata( inputs={"image": input_metadata}, outputs={"prediction": output_metadata} ) # Assemble the explanation specification - explanation_spec = aiplatform.gapic.ExplanationSpec( + explanation_spec = aiplatform_v1beta1.ExplanationSpec( parameters=parameters, metadata=metadata ) - model = aiplatform.gapic.Model( + model = aiplatform_v1beta1.Model( display_name=display_name, # The Cloud Storage location of the custom model artifact_uri=artifact_uri, diff --git a/samples/snippets/upload_model_explain_tabular_managed_container_sample.py b/samples/snippets/upload_model_explain_tabular_managed_container_sample.py index 5449987bd8..36e2b6dbe6 100644 --- a/samples/snippets/upload_model_explain_tabular_managed_container_sample.py +++ b/samples/snippets/upload_model_explain_tabular_managed_container_sample.py @@ -13,7 +13,7 @@ # limitations under the License. # [START aiplatform_upload_model_explain_tabular_managed_container_sample] -from google.cloud import aiplatform +from google.cloud import aiplatform_v1beta1 def upload_model_explain_tabular_managed_container_sample( @@ -32,19 +32,19 @@ def upload_model_explain_tabular_managed_container_sample( client_options = {"api_endpoint": api_endpoint} # Initialize client that will be used to create and send requests. # This client only needs to be created once, and can be reused for multiple requests. - client = aiplatform.gapic.ModelServiceClient(client_options=client_options) + client = aiplatform_v1beta1.ModelServiceClient(client_options=client_options) # Container specification for deploying the model container_spec = {"image_uri": container_spec_image_uri, "command": [], "args": []} # The explainability method and corresponding parameters - parameters = aiplatform.gapic.ExplanationParameters( + parameters = aiplatform_v1beta1.ExplanationParameters( {"xrai_attribution": {"step_count": 1}} ) # The input tensor for feature attribution to the output # For single input model, y = f(x), this will be the serving input layer.
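# Note: the input/output tensor names supplied here must match the tensor
# names in the uploaded model's serving signature, or later explanation
# requests against the deployed model will fail.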
- input_metadata = aiplatform.gapic.ExplanationMetadata.InputMetadata( + input_metadata = aiplatform_v1beta1.ExplanationMetadata.InputMetadata( { "input_tensor_name": input_tensor_name, # Input is tabular data @@ -57,21 +57,21 @@ def upload_model_explain_tabular_managed_container_sample( # The output tensor to explain # For single output model, y = f(x), this will be the serving output layer. - output_metadata = aiplatform.gapic.ExplanationMetadata.OutputMetadata( + output_metadata = aiplatform_v1beta1.ExplanationMetadata.OutputMetadata( {"output_tensor_name": output_tensor_name} ) # Assemble the explanation metadata - metadata = aiplatform.gapic.ExplanationMetadata( + metadata = aiplatform_v1beta1.ExplanationMetadata( inputs={"features": input_metadata}, outputs={"prediction": output_metadata} ) # Assemble the explanation specification - explanation_spec = aiplatform.gapic.ExplanationSpec( + explanation_spec = aiplatform_v1beta1.ExplanationSpec( parameters=parameters, metadata=metadata ) - model = aiplatform.gapic.Model( + model = aiplatform_v1beta1.Model( display_name=display_name, # The Cloud Storage location of the custom model artifact_uri=artifact_uri, diff --git a/synth.metadata b/synth.metadata index b39f24bbb9..5167eaca25 100644 --- a/synth.metadata +++ b/synth.metadata @@ -4,21 +4,29 @@ "git": { "name": ".", "remote": "https://github.com/googleapis/python-aiplatform.git", - "sha": "688a06ff0bcc291cb63225787b7083e0b96b3615" + "sha": "fd36abe48afa3bd4d95f152b97a65613cc2ff23c" + } + }, + { + "git": { + "name": "googleapis", + "remote": "https://github.com/googleapis/googleapis.git", + "sha": "7f6e0d54743dcb86c618b2f78aac2d51e02834b5", + "internalRef": "355883280" } }, { "git": { "name": "synthtool", "remote": "https://github.com/googleapis/synthtool.git", - "sha": "f94318521f63085b9ccb43d42af89f153fb39f15" + "sha": "692715c0f23a7bb3bfbbaa300f7620ddfa8c47e5" } }, { "git": { "name": "synthtool", "remote": "https://github.com/googleapis/synthtool.git", - "sha": "f94318521f63085b9ccb43d42af89f153fb39f15" + "sha": "692715c0f23a7bb3bfbbaa300f7620ddfa8c47e5" } } ], @@ -31,6 +39,15 @@ "language": "python", "generator": "bazel" } + }, + { + "client": { + "source": "googleapis", + "apiName": "aiplatform", + "apiVersion": "v1", + "language": "python", + "generator": "bazel" + } } ] } \ No newline at end of file diff --git a/synth.py b/synth.py index ee86460430..70bd6f3701 100644 --- a/synth.py +++ b/synth.py @@ -28,37 +28,59 @@ # Generate AI Platform GAPIC layer # ---------------------------------------------------------------------------- -library = gapic.py_library( - service="aiplatform", - version="v1beta1", - bazel_target="//google/cloud/aiplatform/v1beta1:aiplatform-v1beta1-py", -) -s.move( - library, - excludes=[ - ".pre-commit-config.yaml", - "setup.py", - "README.rst", - "docs/index.rst", - "docs/definition_v1beta1/services.rst", - "docs/instance_v1beta1/services.rst", - "docs/params_v1beta1/services.rst", - "docs/prediction_v1beta1/services.rst", - "scripts/fixup_aiplatform_v1beta1_keywords.py", - "scripts/fixup_definition_v1beta1_keywords.py", - "scripts/fixup_instance_v1beta1_keywords.py", - "scripts/fixup_params_v1beta1_keywords.py", - "scripts/fixup_prediction_v1beta1_keywords.py", - "google/cloud/aiplatform/__init__.py", - "google/cloud/aiplatform/v1beta1/schema/**/services/", - "tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py", - "tests/unit/gapic/definition_v1beta1/", - "tests/unit/gapic/instance_v1beta1/", - "tests/unit/gapic/params_v1beta1/", - 
"tests/unit/gapic/prediction_v1beta1/", - ], -) +versions = ["v1beta1", "v1"] + +for version in versions: + library = gapic.py_library( + service="aiplatform", + version=version, + bazel_target=f"//google/cloud/aiplatform/{version}:aiplatform-{version}-py", + ) + + s.move( + library, + excludes=[ + ".pre-commit-config.yaml", + "setup.py", + "README.rst", + "docs/index.rst", + f"docs/definition_{version}/services.rst", + f"docs/instance_{version}/services.rst", + f"docs/params_{version}/services.rst", + f"docs/prediction_{version}/services.rst", + f"scripts/fixup_aiplatform_{version}_keywords.py", + f"scripts/fixup_definition_{version}_keywords.py", + f"scripts/fixup_instance_{version}_keywords.py", + f"scripts/fixup_params_{version}_keywords.py", + f"scripts/fixup_prediction_{version}_keywords.py", + "google/cloud/aiplatform/__init__.py", + f"google/cloud/aiplatform/{version}/schema/**/services/", + f"tests/unit/gapic/aiplatform_{version}/test_prediction_service.py", + f"tests/unit/gapic/definition_{version}/", + f"tests/unit/gapic/instance_{version}/", + f"tests/unit/gapic/params_{version}/", + f"tests/unit/gapic/prediction_{version}/", + ], + ) + + # --------------------------------------------------------------------- + # Patch each version of the library + # --------------------------------------------------------------------- + + # https://github.com/googleapis/gapic-generator-python/issues/413 + s.replace( + f"google/cloud/aiplatform_{version}/services/prediction_service/client.py", + "request.instances = instances", + "request.instances.extend(instances)", + ) + + # https://github.com/googleapis/gapic-generator-python/issues/672 + s.replace( + "google/cloud/aiplatform_{version}/services/endpoint_service/client.py", + "request.traffic_split.extend\(traffic_split\)", + "request.traffic_split = traffic_split", + ) # ---------------------------------------------------------------------------- # Patch the library @@ -70,32 +92,18 @@ "client_options: ClientOptions.ClientOptions = ", ) -# https://github.com/googleapis/gapic-generator-python/issues/413 -s.replace( - "google/cloud/aiplatform_v1beta1/services/prediction_service/client.py", - "request.instances = instances", - "request.instances.extend(instances)", -) - -# https://github.com/googleapis/gapic-generator-python/issues/672 -s.replace( - "google/cloud/aiplatform_v1beta1/services/endpoint_service/client.py", - "request.traffic_split.extend\(traffic_split\)", - "request.traffic_split = traffic_split", -) - # Generator adds a bad import statement to enhanced type; # need to fix in post-processing steps. 
-s.replace( - "google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py", - "text_sentiment_pb2 as gcaspi_text_sentiment # type: ignore", - "TextSentimentPredictionInstance") - -s.replace( - "google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py", - "message=gcaspi_text_sentiment.TextSentimentPredictionInstance,", - "message=TextSentimentPredictionInstance,") +#s.replace( +# "google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py", +# "text_sentiment_pb2 as gcaspi_text_sentiment # type: ignore", +# "TextSentimentPredictionInstance") + +#s.replace( +# "google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/types/text_sentiment.py", +# "message=gcaspi_text_sentiment.TextSentimentPredictionInstance,", +# "message=TextSentimentPredictionInstance,") diff --git a/tests/unit/gapic/aiplatform_v1/__init__.py b/tests/unit/gapic/aiplatform_v1/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/unit/gapic/aiplatform_v1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py new file mode 100644 index 0000000000..d03570e876 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1/test_dataset_service.py @@ -0,0 +1,3553 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
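# Unit tests for the aiplatform_v1 DatasetService clients. The gRPC transport
# is mocked throughout, so the suite runs offline with anonymous credentials.
# A typical standalone invocation (assuming pytest is installed, per the
# noxfile above) is:
#
#   pytest tests/unit/gapic/aiplatform_v1/test_dataset_service.py -q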
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1.services.dataset_service import ( + DatasetServiceAsyncClient, +) +from google.cloud.aiplatform_v1.services.dataset_service import DatasetServiceClient +from google.cloud.aiplatform_v1.services.dataset_service import pagers +from google.cloud.aiplatform_v1.services.dataset_service import transports +from google.cloud.aiplatform_v1.types import annotation +from google.cloud.aiplatform_v1.types import annotation_spec +from google.cloud.aiplatform_v1.types import data_item +from google.cloud.aiplatform_v1.types import dataset +from google.cloud.aiplatform_v1.types import dataset as gca_dataset +from google.cloud.aiplatform_v1.types import dataset_service +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import io +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
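# (DEFAULT_ENDPOINT is normally aiplatform.googleapis.com; substituting
# foo.googleapis.com gives the tests a distinct mTLS variant,
# foo.mtls.googleapis.com, to assert against.)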
+def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert DatasetServiceClient._get_default_mtls_endpoint(None) is None + assert ( + DatasetServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DatasetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) + + +def test_dataset_service_client_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = DatasetServiceClient.from_service_account_info(info) + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +@pytest.mark.parametrize( + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,] +) +def test_dataset_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_dataset_service_client_get_transport_class(): + transport = DatasetServiceClient.get_transport_class() + available_transports = [ + transports.DatasetServiceGrpcTransport, + ] + assert transport in available_transports + + transport = DatasetServiceClient.get_transport_class("grpc") + assert transport == transports.DatasetServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) +def test_dataset_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. 
+ with mock.patch.object(DatasetServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
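# Only "true" and "false" are accepted for this variable; any other value
# raises ValueError at client construction time.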
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "true"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc", "false"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + DatasetServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceClient), +) +@mock.patch.object( + DatasetServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DatasetServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_dataset_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_dataset_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DatasetServiceClient, transports.DatasetServiceGrpcTransport, "grpc"), + ( + DatasetServiceAsyncClient, + transports.DatasetServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_dataset_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
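# A credentials file passed via client_options is forwarded to the transport
# as credentials_file and used in place of application default credentials.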
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_dataset_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = DatasetServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_dataset( + transport: str = "grpc", request_type=dataset_service.CreateDatasetRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.CreateDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_dataset_from_dict(): + test_create_dataset(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.CreateDatasetRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.CreateDatasetRequest() + + # Establish that the response is the type that we expect. 
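# CreateDataset is a long-running operation, so the client wraps the raw
# operations_pb2.Operation in a google.api_core future rather than returning
# it directly.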
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_dataset_async_from_dict(): + await test_create_dataset_async(request_type=dict) + + +def test_create_dataset_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.CreateDatasetRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_dataset_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.CreateDatasetRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.create_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_dataset_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_dataset( + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].dataset == gca_dataset.Dataset(name="name_value") + + +def test_create_dataset_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_dataset( + dataset_service.CreateDatasetRequest(), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_dataset_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_dataset( + parent="parent_value", dataset=gca_dataset.Dataset(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].dataset == gca_dataset.Dataset(name="name_value") + + +@pytest.mark.asyncio +async def test_create_dataset_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_dataset( + dataset_service.CreateDatasetRequest(), + parent="parent_value", + dataset=gca_dataset.Dataset(name="name_value"), + ) + + +def test_get_dataset( + transport: str = "grpc", request_type=dataset_service.GetDatasetRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + + response = client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetDatasetRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, dataset.Dataset) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.etag == "etag_value" + + +def test_get_dataset_from_dict(): + test_get_dataset(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.GetDatasetRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+ request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) + + response = await client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, dataset.Dataset) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_get_dataset_async_from_dict(): + await test_get_dataset_async(request_type=dict) + + +def test_get_dataset_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetDatasetRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + call.return_value = dataset.Dataset() + + client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_dataset_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetDatasetRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) + + await client.get_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_dataset_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset.Dataset() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_dataset(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_dataset_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_dataset( + dataset_service.GetDatasetRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_dataset_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset.Dataset() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(dataset.Dataset()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_dataset(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_dataset_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_dataset( + dataset_service.GetDatasetRequest(), name="name_value", + ) + + +def test_update_dataset( + transport: str = "grpc", request_type=dataset_service.UpdateDatasetRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + + response = client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.UpdateDatasetRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_dataset.Dataset) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.etag == "etag_value" + + +def test_update_dataset_from_dict(): + test_update_dataset(request_type=dict) + + +@pytest.mark.asyncio +async def test_update_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.UpdateDatasetRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
+ request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_dataset.Dataset( + name="name_value", + display_name="display_name_value", + metadata_schema_uri="metadata_schema_uri_value", + etag="etag_value", + ) + ) + + response = await client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.UpdateDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_dataset.Dataset) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_update_dataset_async_from_dict(): + await test_update_dataset_async(request_type=dict) + + +def test_update_dataset_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.UpdateDatasetRequest() + request.dataset.name = "dataset.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + call.return_value = gca_dataset.Dataset() + + client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ + "metadata" + ] + + +@pytest.mark.asyncio +async def test_update_dataset_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.UpdateDatasetRequest() + request.dataset.name = "dataset.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) + + await client.update_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "dataset.name=dataset.name/value",) in kw[ + "metadata" + ] + + +def test_update_dataset_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = gca_dataset.Dataset() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_dataset( + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].dataset == gca_dataset.Dataset(name="name_value") + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +def test_update_dataset_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_dataset( + dataset_service.UpdateDatasetRequest(), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_dataset_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_dataset.Dataset() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_dataset.Dataset()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_dataset( + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].dataset == gca_dataset.Dataset(name="name_value") + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +@pytest.mark.asyncio +async def test_update_dataset_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_dataset( + dataset_service.UpdateDatasetRequest(), + dataset=gca_dataset.Dataset(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_list_datasets( + transport: str = "grpc", request_type=dataset_service.ListDatasetsRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDatasetsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDatasetsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListDatasetsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_datasets_from_dict(): + test_list_datasets(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_datasets_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListDatasetsRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDatasetsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDatasetsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_datasets_async_from_dict(): + await test_list_datasets_async(request_type=dict) + + +def test_list_datasets_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDatasetsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + call.return_value = dataset_service.ListDatasetsResponse() + + client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_datasets_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDatasetsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse() + ) + + await client.list_datasets(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_datasets_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDatasetsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_datasets(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_datasets_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_datasets( + dataset_service.ListDatasetsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_datasets_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDatasetsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDatasetsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_datasets(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_datasets_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_datasets( + dataset_service.ListDatasetsRequest(), parent="parent_value", + ) + + +def test_list_datasets_pager(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + # Set the response to a series of pages. 
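# Four pages are simulated below: tokens "abc", "def", "ghi", then a final
# page with an empty token. The trailing RuntimeError would only surface if
# the pager incorrectly requested a page past the end.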
+ call.side_effect = ( + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", + ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(),], next_page_token="ghi", + ), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(),], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_datasets(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, dataset.Dataset) for i in results) + + +def test_list_datasets_pages(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_datasets), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", + ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(),], next_page_token="ghi", + ), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(),], + ), + RuntimeError, + ) + pages = list(client.list_datasets(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_datasets_async_pager(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", + ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(),], next_page_token="ghi", + ), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(),], + ), + RuntimeError, + ) + async_pager = await client.list_datasets(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, dataset.Dataset) for i in responses) + + +@pytest.mark.asyncio +async def test_list_datasets_async_pages(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_datasets), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. 
+ call.side_effect = ( + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(), dataset.Dataset(),], + next_page_token="abc", + ), + dataset_service.ListDatasetsResponse(datasets=[], next_page_token="def",), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(),], next_page_token="ghi", + ), + dataset_service.ListDatasetsResponse( + datasets=[dataset.Dataset(), dataset.Dataset(),], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_datasets(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_dataset( + transport: str = "grpc", request_type=dataset_service.DeleteDatasetRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.DeleteDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_dataset_from_dict(): + test_delete_dataset(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_dataset_async( + transport: str = "grpc_asyncio", request_type=dataset_service.DeleteDatasetRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.DeleteDatasetRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_dataset_async_from_dict(): + await test_delete_dataset_async(request_type=dict) + + +def test_delete_dataset_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.DeleteDatasetRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
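+    # These assertions rely on the generated client packing request.name
+    # into the "x-goog-request-params" gRPC metadata entry (here as
+    # "name=name/value") so the backend can route the call by resource.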
+ with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_dataset_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.DeleteDatasetRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_dataset(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_dataset_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_dataset(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_dataset_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_dataset( + dataset_service.DeleteDatasetRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_dataset_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_dataset), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_dataset(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_dataset_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_dataset( + dataset_service.DeleteDatasetRequest(), name="name_value", + ) + + +def test_import_data( + transport: str = "grpc", request_type=dataset_service.ImportDataRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.import_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ImportDataRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_import_data_from_dict(): + test_import_data(request_type=dict) + + +@pytest.mark.asyncio +async def test_import_data_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ImportDataRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.import_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ImportDataRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_import_data_async_from_dict(): + await test_import_data_async(request_type=dict) + + +def test_import_data_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ImportDataRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.import_data(request) + + # Establish that the underlying gRPC stub method was called. 
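+        # Entries in call.mock_calls unpack as (name, args, kwargs)
+        # triples, so args[0] below is the request object that actually
+        # reached the mocked stub.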
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_import_data_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ImportDataRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.import_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_import_data_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.import_data( + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + assert args[0].import_configs == [ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ] + + +def test_import_data_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.import_data( + dataset_service.ImportDataRequest(), + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], + ) + + +@pytest.mark.asyncio +async def test_import_data_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.import_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ response = await client.import_data( + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + assert args[0].import_configs == [ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ] + + +@pytest.mark.asyncio +async def test_import_data_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.import_data( + dataset_service.ImportDataRequest(), + name="name_value", + import_configs=[ + dataset.ImportDataConfig(gcs_source=io.GcsSource(uris=["uris_value"])) + ], + ) + + +def test_export_data( + transport: str = "grpc", request_type=dataset_service.ExportDataRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.export_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ExportDataRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_export_data_from_dict(): + test_export_data(request_type=dict) + + +@pytest.mark.asyncio +async def test_export_data_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ExportDataRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.export_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ExportDataRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_export_data_async_from_dict(): + await test_export_data_async(request_type=dict) + + +def test_export_data_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
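+    # The "/" in the value mirrors a real resource path; the test expects
+    # it to be carried into the routing header verbatim.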
+ request = dataset_service.ExportDataRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.export_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_export_data_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ExportDataRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.export_data(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_export_data_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.export_data( + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + assert args[0].export_config == dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ) + + +def test_export_data_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.export_data( + dataset_service.ExportDataRequest(), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), + ) + + +@pytest.mark.asyncio +async def test_export_data_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.export_data), "__call__") as call: + # Designate an appropriate return value for the call. 
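+        # Two return values are assigned below; the second, an awaitable
+        # grpc_helpers_async.FakeUnaryUnaryCall wrapping the Operation, is
+        # the one the mocked async stub actually hands back.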
+ call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.export_data( + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + assert args[0].export_config == dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ) + + +@pytest.mark.asyncio +async def test_export_data_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.export_data( + dataset_service.ExportDataRequest(), + name="name_value", + export_config=dataset.ExportDataConfig( + gcs_destination=io.GcsDestination( + output_uri_prefix="output_uri_prefix_value" + ) + ), + ) + + +def test_list_data_items( + transport: str = "grpc", request_type=dataset_service.ListDataItemsRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDataItemsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDataItemsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListDataItemsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_data_items_from_dict(): + test_list_data_items(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_data_items_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListDataItemsRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListDataItemsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDataItemsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_data_items_async_from_dict(): + await test_list_data_items_async(request_type=dict) + + +def test_list_data_items_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDataItemsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + call.return_value = dataset_service.ListDataItemsResponse() + + client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_data_items_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListDataItemsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListDataItemsResponse() + ) + + await client.list_data_items(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_data_items_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_data_items), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListDataItemsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_data_items(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_data_items_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
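+    # The flattened parent= keyword is sugar for populating a request, so
+    # combining it with an explicit request object would be ambiguous and
+    # the client is expected to raise ValueError before any RPC is made.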
+    with pytest.raises(ValueError):
+        client.list_data_items(
+            dataset_service.ListDataItemsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_data_items_flattened_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_data_items), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = dataset_service.ListDataItemsResponse()
+
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            dataset_service.ListDataItemsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_data_items(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_data_items_flattened_error_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_data_items(
+            dataset_service.ListDataItemsRequest(), parent="parent_value",
+        )
+
+
+def test_list_data_items_pager():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_data_items), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                ],
+                next_page_token="abc",
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[], next_page_token="def",
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[data_item.DataItem(),], next_page_token="ghi",
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[data_item.DataItem(), data_item.DataItem(),],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_data_items(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, data_item.DataItem) for i in results)
+
+
+def test_list_data_items_pages():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_data_items), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                ],
+                next_page_token="abc",
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[], next_page_token="def",
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[data_item.DataItem(),], next_page_token="ghi",
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[data_item.DataItem(), data_item.DataItem(),],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_data_items(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_data_items_async_pager():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                ],
+                next_page_token="abc",
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[], next_page_token="def",
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[data_item.DataItem(),], next_page_token="ghi",
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[data_item.DataItem(), data_item.DataItem(),],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_data_items(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, data_item.DataItem) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_data_items_async_pages():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_data_items), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListDataItemsResponse(
+                data_items=[
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                    data_item.DataItem(),
+                ],
+                next_page_token="abc",
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[], next_page_token="def",
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[data_item.DataItem(),], next_page_token="ghi",
+            ),
+            dataset_service.ListDataItemsResponse(
+                data_items=[data_item.DataItem(), data_item.DataItem(),],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_data_items(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_get_annotation_spec(
+    transport: str = "grpc", request_type=dataset_service.GetAnnotationSpecRequest
+):
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = annotation_spec.AnnotationSpec( + name="name_value", display_name="display_name_value", etag="etag_value", + ) + + response = client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetAnnotationSpecRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, annotation_spec.AnnotationSpec) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.etag == "etag_value" + + +def test_get_annotation_spec_from_dict(): + test_get_annotation_spec(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_annotation_spec_async( + transport: str = "grpc_asyncio", + request_type=dataset_service.GetAnnotationSpecRequest, +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec( + name="name_value", display_name="display_name_value", etag="etag_value", + ) + ) + + response = await client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.GetAnnotationSpecRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, annotation_spec.AnnotationSpec) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_get_annotation_spec_async_from_dict(): + await test_get_annotation_spec_async(request_type=dict) + + +def test_get_annotation_spec_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetAnnotationSpecRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + call.return_value = annotation_spec.AnnotationSpec() + + client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_annotation_spec_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.GetAnnotationSpecRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) + + await client.get_annotation_spec(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_annotation_spec_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = annotation_spec.AnnotationSpec() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_annotation_spec(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_annotation_spec_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_annotation_spec( + dataset_service.GetAnnotationSpecRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_annotation_spec_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_annotation_spec), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = annotation_spec.AnnotationSpec() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + annotation_spec.AnnotationSpec() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_annotation_spec(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_annotation_spec_flattened_error_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.get_annotation_spec( + dataset_service.GetAnnotationSpecRequest(), name="name_value", + ) + + +def test_list_annotations( + transport: str = "grpc", request_type=dataset_service.ListAnnotationsRequest +): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListAnnotationsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListAnnotationsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListAnnotationsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_annotations_from_dict(): + test_list_annotations(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_annotations_async( + transport: str = "grpc_asyncio", request_type=dataset_service.ListAnnotationsRequest +): + client = DatasetServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == dataset_service.ListAnnotationsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListAnnotationsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_annotations_async_from_dict(): + await test_list_annotations_async(request_type=dict) + + +def test_list_annotations_field_headers(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListAnnotationsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value = dataset_service.ListAnnotationsResponse() + + client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_annotations_field_headers_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = dataset_service.ListAnnotationsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse() + ) + + await client.list_annotations(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_annotations_flattened(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListAnnotationsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_annotations(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_annotations_flattened_error(): + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_annotations( + dataset_service.ListAnnotationsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_annotations_flattened_async(): + client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_annotations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = dataset_service.ListAnnotationsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + dataset_service.ListAnnotationsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_annotations(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_annotations_flattened_error_async():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_annotations(
+            dataset_service.ListAnnotationsRequest(), parent="parent_value",
+        )
+
+
+def test_list_annotations_pager():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_annotations), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                ],
+                next_page_token="abc",
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[], next_page_token="def",
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[annotation.Annotation(),], next_page_token="ghi",
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[annotation.Annotation(), annotation.Annotation(),],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_annotations(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, annotation.Annotation) for i in results)
+
+
+def test_list_annotations_pages():
+    client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_annotations), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                ],
+                next_page_token="abc",
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[], next_page_token="def",
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[annotation.Annotation(),], next_page_token="ghi",
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[annotation.Annotation(), annotation.Annotation(),],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_annotations(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_annotations_async_pager():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
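+        # new_callable=mock.AsyncMock (above) makes the mocked __call__
+        # awaitable, as the asyncio transport expects; side_effect entries
+        # are then consumed in order, one per awaited page fetch.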
+        call.side_effect = (
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                ],
+                next_page_token="abc",
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[], next_page_token="def",
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[annotation.Annotation(),], next_page_token="ghi",
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[annotation.Annotation(), annotation.Annotation(),],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_annotations(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, annotation.Annotation) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_annotations_async_pages():
+    client = DatasetServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_annotations), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            dataset_service.ListAnnotationsResponse(
+                annotations=[
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                    annotation.Annotation(),
+                ],
+                next_page_token="abc",
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[], next_page_token="def",
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[annotation.Annotation(),], next_page_token="ghi",
+            ),
+            dataset_service.ListAnnotationsResponse(
+                annotations=[annotation.Annotation(), annotation.Annotation(),],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_annotations(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_credentials_transport_error():
+    # It is an error to provide credentials and a transport instance.
+    transport = transports.DatasetServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = DatasetServiceClient(
+            credentials=credentials.AnonymousCredentials(), transport=transport,
+        )
+
+    # It is an error to provide a credentials file and a transport instance.
+    transport = transports.DatasetServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = DatasetServiceClient(
+            client_options={"credentials_file": "credentials.json"},
+            transport=transport,
+        )
+
+    # It is an error to provide scopes and a transport instance.
+    transport = transports.DatasetServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = DatasetServiceClient(
+            client_options={"scopes": ["1", "2"]}, transport=transport,
+        )
+
+
+def test_transport_instance():
+    # A client may be instantiated with a custom transport instance.
+    transport = transports.DatasetServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    client = DatasetServiceClient(transport=transport)
+    assert client.transport is transport
+
+
+def test_transport_get_channel():
+    # A client may be instantiated with a custom transport instance.
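+    # Both the sync and asyncio transports expose their underlying channel
+    # via .grpc_channel; this test only checks that one exists for each.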
+ transport = transports.DatasetServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.DatasetServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = DatasetServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.DatasetServiceGrpcTransport,) + + +def test_dataset_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.DatasetServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_dataset_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.DatasetServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_dataset", + "get_dataset", + "update_dataset", + "list_datasets", + "delete_dataset", + "import_data", + "export_data", + "list_data_items", + "get_annotation_spec", + "list_annotations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_dataset_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.DatasetServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_dataset_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. 
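+    # google.auth.default() is patched so no real ADC lookup happens; the
+    # transport should fall back to it exactly once when neither
+    # credentials nor credentials_file is supplied.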
+ with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.dataset_service.transports.DatasetServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.DatasetServiceTransport() + adc.assert_called_once() + + +def test_dataset_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + DatasetServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_dataset_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.DatasetServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) +def test_dataset_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_dataset_service_host_no_port(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_dataset_service_host_with_port(): + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:8000" + + +def test_dataset_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. 
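+    # With an explicit channel supplied, the transport should adopt it
+    # as-is rather than creating its own; the host is still recorded but
+    # never dialed.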
+    transport = transports.DatasetServiceGrpcTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_dataset_service_grpc_asyncio_transport_channel():
+    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.DatasetServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.DatasetServiceGrpcTransport,
+        transports.DatasetServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_dataset_service_transport_channel_mtls_with_client_cert_source(
+    transport_class,
+):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel"
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
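+# (Editor's note, a sketch rather than generated code: under the deprecated
+# arguments, passing client_cert_source=None makes the transport pull its SSL
+# credentials from ADC, roughly:
+#
+#     ssl_creds = google.auth.transport.grpc.SslCredentials().ssl_credentials
+#
+# which is why the test below patches SslCredentials wholesale.)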
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.DatasetServiceGrpcTransport,
+        transports.DatasetServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_dataset_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(
+            transport_class, "create_channel"
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_dataset_service_grpc_lro_client():
+    client = DatasetServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property return the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_dataset_service_grpc_lro_async_client():
+    client = DatasetServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
+
+    # Ensure that subsequent calls to the property return the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_annotation_path():
+    project = "squid"
+    location = "clam"
+    dataset = "whelk"
+    data_item = "octopus"
+    annotation = "oyster"
+
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}/annotations/{annotation}".format(
+        project=project,
+        location=location,
+        dataset=dataset,
+        data_item=data_item,
+        annotation=annotation,
+    )
+    actual = DatasetServiceClient.annotation_path(
+        project, location, dataset, data_item, annotation
+    )
+    assert expected == actual
+
+
+def test_parse_annotation_path():
+    expected = {
+        "project": "nudibranch",
+        "location": "cuttlefish",
+        "dataset": "mussel",
+        "data_item": "winkle",
+        "annotation": "nautilus",
+    }
+    path = DatasetServiceClient.annotation_path(**expected)
+
+    # Check that the path construction is reversible.
+ actual = DatasetServiceClient.parse_annotation_path(path) + assert expected == actual + + +def test_annotation_spec_path(): + project = "scallop" + location = "abalone" + dataset = "squid" + annotation_spec = "clam" + + expected = "projects/{project}/locations/{location}/datasets/{dataset}/annotationSpecs/{annotation_spec}".format( + project=project, + location=location, + dataset=dataset, + annotation_spec=annotation_spec, + ) + actual = DatasetServiceClient.annotation_spec_path( + project, location, dataset, annotation_spec + ) + assert expected == actual + + +def test_parse_annotation_spec_path(): + expected = { + "project": "whelk", + "location": "octopus", + "dataset": "oyster", + "annotation_spec": "nudibranch", + } + path = DatasetServiceClient.annotation_spec_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_annotation_spec_path(path) + assert expected == actual + + +def test_data_item_path(): + project = "cuttlefish" + location = "mussel" + dataset = "winkle" + data_item = "nautilus" + + expected = "projects/{project}/locations/{location}/datasets/{dataset}/dataItems/{data_item}".format( + project=project, location=location, dataset=dataset, data_item=data_item, + ) + actual = DatasetServiceClient.data_item_path(project, location, dataset, data_item) + assert expected == actual + + +def test_parse_data_item_path(): + expected = { + "project": "scallop", + "location": "abalone", + "dataset": "squid", + "data_item": "clam", + } + path = DatasetServiceClient.data_item_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_data_item_path(path) + assert expected == actual + + +def test_dataset_path(): + project = "whelk" + location = "octopus" + dataset = "oyster" + + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project=project, location=location, dataset=dataset, + ) + actual = DatasetServiceClient.dataset_path(project, location, dataset) + assert expected == actual + + +def test_parse_dataset_path(): + expected = { + "project": "nudibranch", + "location": "cuttlefish", + "dataset": "mussel", + } + path = DatasetServiceClient.dataset_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_dataset_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "winkle" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = DatasetServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nautilus", + } + path = DatasetServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "scallop" + + expected = "folders/{folder}".format(folder=folder,) + actual = DatasetServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "abalone", + } + path = DatasetServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. 
+ actual = DatasetServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "squid" + + expected = "organizations/{organization}".format(organization=organization,) + actual = DatasetServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "clam", + } + path = DatasetServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "whelk" + + expected = "projects/{project}".format(project=project,) + actual = DatasetServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "octopus", + } + path = DatasetServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "oyster" + location = "nudibranch" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = DatasetServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "cuttlefish", + "location": "mussel", + } + path = DatasetServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = DatasetServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.DatasetServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = DatasetServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.DatasetServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = DatasetServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py new file mode 100644 index 0000000000..227af94bf8 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1/test_endpoint_service.py @@ -0,0 +1,2634 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1.services.endpoint_service import ( + EndpointServiceAsyncClient, +) +from google.cloud.aiplatform_v1.services.endpoint_service import EndpointServiceClient +from google.cloud.aiplatform_v1.services.endpoint_service import pagers +from google.cloud.aiplatform_v1.services.endpoint_service import transports +from google.cloud.aiplatform_v1.types import accelerator_type +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import endpoint +from google.cloud.aiplatform_v1.types import endpoint as gca_endpoint +from google.cloud.aiplatform_v1.types import endpoint_service +from google.cloud.aiplatform_v1.types import machine_resources +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
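+# (Editor's note: DEFAULT_MTLS_ENDPOINT is conventionally derived from
+# DEFAULT_ENDPOINT by inserting an "mtls" label, e.g.
+# "aiplatform.googleapis.com" -> "aiplatform.mtls.googleapis.com" and
+# "example.sandbox.googleapis.com" -> "example.mtls.sandbox.googleapis.com";
+# test__get_default_mtls_endpoint below pins those conversions down.)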
+def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert EndpointServiceClient._get_default_mtls_endpoint(None) is None + assert ( + EndpointServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + EndpointServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) + + +def test_endpoint_service_client_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = EndpointServiceClient.from_service_account_info(info) + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +@pytest.mark.parametrize( + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,] +) +def test_endpoint_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_endpoint_service_client_get_transport_class(): + transport = EndpointServiceClient.get_transport_class() + available_transports = [ + transports.EndpointServiceGrpcTransport, + ] + assert transport in available_transports + + transport = EndpointServiceClient.get_transport_class("grpc") + assert transport == transports.EndpointServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) +def test_endpoint_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. 
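+    # (Editor's sketch: a transport given as a string is resolved through
+    # get_transport_class, roughly:
+    #
+    #     transport_cls = EndpointServiceClient.get_transport_class("grpc")
+    #     assert transport_cls is transports.EndpointServiceGrpcTransport
+    #
+    # so the factory must be consulted in this branch.)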
+ with mock.patch.object(EndpointServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + "true", + ), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + EndpointServiceClient, + transports.EndpointServiceGrpcTransport, + "grpc", + "false", + ), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + EndpointServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceClient), +) +@mock.patch.object( + EndpointServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(EndpointServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_endpoint_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
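+    # (Editor's note: "ADC client cert" means the context-aware mTLS
+    # certificate google.auth can discover on the machine; its presence is
+    # exactly what google.auth.transport.mtls.has_default_client_cert_source
+    # reports, so both discovery helpers are patched in the block below.)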
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_endpoint_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (EndpointServiceClient, transports.EndpointServiceGrpcTransport, "grpc"), + ( + EndpointServiceAsyncClient, + transports.EndpointServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_endpoint_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
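+    # (Editor's sketch: "credentials.json" stands in for a service-account
+    # key file; the transport would normally load it via something like
+    # google.auth.load_credentials_from_file("credentials.json"), so this
+    # test only asserts that the file path is forwarded unchanged.)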
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_endpoint_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = EndpointServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_endpoint( + transport: str = "grpc", request_type=endpoint_service.CreateEndpointRequest +): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.CreateEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_endpoint_from_dict(): + test_create_endpoint(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.CreateEndpointRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.CreateEndpointRequest() + + # Establish that the response is the type that we expect. 
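+    # (Editor's note: CreateEndpoint is a long-running operation, so the
+    # client wraps the raw operations_pb2.Operation in a future; a caller
+    # would typically resolve it with something like
+    #
+    #     lro = await client.create_endpoint(request)
+    #     endpoint = await lro.result()  # hypothetical usage, not exercised here
+    #
+    # whereas this test only checks the wrapper type.)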
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_endpoint_async_from_dict(): + await test_create_endpoint_async(request_type=dict) + + +def test_create_endpoint_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.CreateEndpointRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_endpoint_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.CreateEndpointRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.create_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_endpoint_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_endpoint( + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + + +def test_create_endpoint_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_endpoint( + endpoint_service.CreateEndpointRequest(), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_endpoint_flattened_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_endpoint( + parent="parent_value", endpoint=gca_endpoint.Endpoint(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + + +@pytest.mark.asyncio +async def test_create_endpoint_flattened_error_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_endpoint( + endpoint_service.CreateEndpointRequest(), + parent="parent_value", + endpoint=gca_endpoint.Endpoint(name="name_value"), + ) + + +def test_get_endpoint( + transport: str = "grpc", request_type=endpoint_service.GetEndpointRequest +): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + + response = client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.GetEndpointRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, endpoint.Endpoint) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +def test_get_endpoint_from_dict(): + test_get_endpoint(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.GetEndpointRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. 
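+    # (Editor's note: every proto3 field has a default, so request_type()
+    # with nothing set is still a valid GetEndpointRequest; any server-side
+    # validation that would reject it is mocked away here.)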
+ request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) + + response = await client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.GetEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, endpoint.Endpoint) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_get_endpoint_async_from_dict(): + await test_get_endpoint_async(request_type=dict) + + +def test_get_endpoint_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.GetEndpointRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + call.return_value = endpoint.Endpoint() + + client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_endpoint_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.GetEndpointRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) + + await client.get_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_endpoint_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint.Endpoint() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_endpoint(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_endpoint_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_endpoint( + endpoint_service.GetEndpointRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_endpoint_flattened_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint.Endpoint() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(endpoint.Endpoint()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_endpoint(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_endpoint_flattened_error_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_endpoint( + endpoint_service.GetEndpointRequest(), name="name_value", + ) + + +def test_list_endpoints( + transport: str = "grpc", request_type=endpoint_service.ListEndpointsRequest +): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint_service.ListEndpointsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.ListEndpointsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListEndpointsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_endpoints_from_dict(): + test_list_endpoints(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_endpoints_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.ListEndpointsRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.ListEndpointsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListEndpointsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_endpoints_async_from_dict(): + await test_list_endpoints_async(request_type=dict) + + +def test_list_endpoints_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.ListEndpointsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + call.return_value = endpoint_service.ListEndpointsResponse() + + client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_endpoints_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.ListEndpointsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + endpoint_service.ListEndpointsResponse() + ) + + await client.list_endpoints(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_endpoints_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = endpoint_service.ListEndpointsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_endpoints(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_endpoints_flattened_error():
+    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_endpoints(
+            endpoint_service.ListEndpointsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_endpoints_flattened_async():
+    client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = endpoint_service.ListEndpointsResponse()
+
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            endpoint_service.ListEndpointsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_endpoints(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_endpoints_flattened_error_async():
+    client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_endpoints(
+            endpoint_service.ListEndpointsRequest(), parent="parent_value",
+        )
+
+
+def test_list_endpoints_pager():
+    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                ],
+                next_page_token="abc",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[], next_page_token="def",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_endpoints(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, endpoint.Endpoint) for i in results)
+
+
+def test_list_endpoints_pages():
+    client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_endpoints), "__call__") as call:
+        # Set the response to a series of pages.
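+        # (Editor's note: each response below models one page; the pager keeps
+        # re-invoking the RPC while next_page_token is non-empty, and the
+        # trailing RuntimeError guards against an over-eager extra call.)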
+        call.side_effect = (
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                ],
+                next_page_token="abc",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[], next_page_token="def",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_endpoints(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_endpoints_async_pager():
+    client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                ],
+                next_page_token="abc",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[], next_page_token="def",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_endpoints(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, endpoint.Endpoint) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_endpoints_async_pages():
+    client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_endpoints), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                    endpoint.Endpoint(),
+                ],
+                next_page_token="abc",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[], next_page_token="def",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(),], next_page_token="ghi",
+            ),
+            endpoint_service.ListEndpointsResponse(
+                endpoints=[endpoint.Endpoint(), endpoint.Endpoint(),],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_endpoints(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_update_endpoint(
+    transport: str = "grpc", request_type=endpoint_service.UpdateEndpointRequest
+):
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call:
+        # Designate an appropriate return value for the call.
+ call.return_value = gca_endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + + response = client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UpdateEndpointRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_endpoint.Endpoint) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +def test_update_endpoint_from_dict(): + test_update_endpoint(request_type=dict) + + +@pytest.mark.asyncio +async def test_update_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.UpdateEndpointRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint( + name="name_value", + display_name="display_name_value", + description="description_value", + etag="etag_value", + ) + ) + + response = await client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UpdateEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_endpoint.Endpoint) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_update_endpoint_async_from_dict(): + await test_update_endpoint_async(request_type=dict) + + +def test_update_endpoint_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.UpdateEndpointRequest() + request.endpoint.name = "endpoint.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + call.return_value = gca_endpoint.Endpoint() + + client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ + "metadata" + ] + + +@pytest.mark.asyncio +async def test_update_endpoint_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. 
Set these to a non-empty value. + request = endpoint_service.UpdateEndpointRequest() + request.endpoint.name = "endpoint.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint() + ) + + await client.update_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint.name=endpoint.name/value",) in kw[ + "metadata" + ] + + +def test_update_endpoint_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_endpoint.Endpoint() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_endpoint( + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +def test_update_endpoint_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_endpoint( + endpoint_service.UpdateEndpointRequest(), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_endpoint_flattened_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = gca_endpoint.Endpoint() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_endpoint.Endpoint() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_endpoint( + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == gca_endpoint.Endpoint(name="name_value") + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +@pytest.mark.asyncio +async def test_update_endpoint_flattened_error_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.update_endpoint( + endpoint_service.UpdateEndpointRequest(), + endpoint=gca_endpoint.Endpoint(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_delete_endpoint( + transport: str = "grpc", request_type=endpoint_service.DeleteEndpointRequest +): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.DeleteEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_endpoint_from_dict(): + test_delete_endpoint(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_endpoint_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.DeleteEndpointRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.DeleteEndpointRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_endpoint_async_from_dict(): + await test_delete_endpoint_async(request_type=dict) + + +def test_delete_endpoint_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.DeleteEndpointRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
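+        # The routing header travels in the call's metadata, so we inspect the
+        # keyword arguments recorded on the mocked stub rather than the request.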
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_endpoint_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.DeleteEndpointRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_endpoint(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_endpoint_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_endpoint(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_endpoint_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_endpoint( + endpoint_service.DeleteEndpointRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_endpoint_flattened_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_endpoint), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_endpoint(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_endpoint_flattened_error_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + await client.delete_endpoint( + endpoint_service.DeleteEndpointRequest(), name="name_value", + ) + + +def test_deploy_model( + transport: str = "grpc", request_type=endpoint_service.DeployModelRequest +): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.DeployModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_deploy_model_from_dict(): + test_deploy_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_deploy_model_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.DeployModelRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.DeployModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_deploy_model_async_from_dict(): + await test_deploy_model_async(request_type=dict) + + +def test_deploy_model_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.DeployModelRequest() + request.endpoint = "endpoint/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
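+        # For DeployModel the routing key is the request's `endpoint` field,
+        # so the expected header value below is "endpoint=endpoint/value".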
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_deploy_model_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.DeployModelRequest() + request.endpoint = "endpoint/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.deploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + + +def test_deploy_model_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.deploy_model( + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == "endpoint_value" + + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) + + assert args[0].traffic_split == {"key_value": 541} + + +def test_deploy_model_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.deploy_model( + endpoint_service.DeployModelRequest(), + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, + ) + + +@pytest.mark.asyncio +async def test_deploy_model_flattened_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.deploy_model), "__call__") as call: + # Designate an appropriate return value for the call. 
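+        # Note: the async client awaits the stub's return value, so the plain
+        # Operation assigned first is immediately overwritten below with an
+        # awaitable FakeUnaryUnaryCall wrapper; only the wrapped value is used.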
+ call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.deploy_model( + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == "endpoint_value" + + assert args[0].deployed_model == gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ) + + assert args[0].traffic_split == {"key_value": 541} + + +@pytest.mark.asyncio +async def test_deploy_model_flattened_error_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.deploy_model( + endpoint_service.DeployModelRequest(), + endpoint="endpoint_value", + deployed_model=gca_endpoint.DeployedModel( + dedicated_resources=machine_resources.DedicatedResources( + machine_spec=machine_resources.MachineSpec( + machine_type="machine_type_value" + ) + ) + ), + traffic_split={"key_value": 541}, + ) + + +def test_undeploy_model( + transport: str = "grpc", request_type=endpoint_service.UndeployModelRequest +): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UndeployModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_undeploy_model_from_dict(): + test_undeploy_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_undeploy_model_async( + transport: str = "grpc_asyncio", request_type=endpoint_service.UndeployModelRequest +): + client = EndpointServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == endpoint_service.UndeployModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_undeploy_model_async_from_dict(): + await test_undeploy_model_async(request_type=dict) + + +def test_undeploy_model_field_headers(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.UndeployModelRequest() + request.endpoint = "endpoint/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_undeploy_model_field_headers_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = endpoint_service.UndeployModelRequest() + request.endpoint = "endpoint/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.undeploy_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "endpoint=endpoint/value",) in kw["metadata"] + + +def test_undeploy_model_flattened(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.undeploy_model( + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, + ) + + # Establish that the underlying call was made with the expected + # request object values. 
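+        # Each keyword argument should have been copied onto the matching field
+        # of the UndeployModelRequest that reached the stub.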
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == "endpoint_value" + + assert args[0].deployed_model_id == "deployed_model_id_value" + + assert args[0].traffic_split == {"key_value": 541} + + +def test_undeploy_model_flattened_error(): + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.undeploy_model( + endpoint_service.UndeployModelRequest(), + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, + ) + + +@pytest.mark.asyncio +async def test_undeploy_model_flattened_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.undeploy_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.undeploy_model( + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].endpoint == "endpoint_value" + + assert args[0].deployed_model_id == "deployed_model_id_value" + + assert args[0].traffic_split == {"key_value": 541} + + +@pytest.mark.asyncio +async def test_undeploy_model_flattened_error_async(): + client = EndpointServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.undeploy_model( + endpoint_service.UndeployModelRequest(), + endpoint="endpoint_value", + deployed_model_id="deployed_model_id_value", + traffic_split={"key_value": 541}, + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EndpointServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = EndpointServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. 
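+    # In that case the client should adopt the instance unchanged, which is
+    # why the transport property is compared by identity below.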
+ transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = EndpointServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.EndpointServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.EndpointServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = EndpointServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.EndpointServiceGrpcTransport,) + + +def test_endpoint_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.EndpointServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_endpoint_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.EndpointServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_endpoint", + "get_endpoint", + "list_endpoints", + "update_endpoint", + "delete_endpoint", + "deploy_model", + "undeploy_model", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_endpoint_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.EndpointServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_endpoint_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. 
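+    # google.auth.default is patched so the test never consults real ADC
+    # credentials on the host machine.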
+ with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.endpoint_service.transports.EndpointServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.EndpointServiceTransport() + adc.assert_called_once() + + +def test_endpoint_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + EndpointServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_endpoint_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.EndpointServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) +def test_endpoint_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_endpoint_service_host_no_port(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_endpoint_service_host_with_port(): + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:8000" + + +def test_endpoint_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. 
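+    # A pre-built channel short-circuits channel creation: the transport
+    # records the host but reuses the supplied channel as-is.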
+    transport = transports.EndpointServiceGrpcTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_endpoint_service_grpc_asyncio_transport_channel():
+    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.EndpointServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.EndpointServiceGrpcTransport,
+        transports.EndpointServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_endpoint_service_transport_channel_mtls_with_client_cert_source(
+    transport_class,
+):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel"
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.EndpointServiceGrpcTransport,
+        transports.EndpointServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_endpoint_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(
+            transport_class, "create_channel"
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_endpoint_service_grpc_lro_client():
+    client = EndpointServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_endpoint_service_grpc_lro_async_client():
+    client = EndpointServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_endpoint_path():
+    project = "squid"
+    location = "clam"
+    endpoint = "whelk"
+
+    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
+        project=project, location=location, endpoint=endpoint,
+    )
+    actual = EndpointServiceClient.endpoint_path(project, location, endpoint)
+    assert expected == actual
+
+
+def test_parse_endpoint_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "endpoint": "nudibranch",
+    }
+    path = EndpointServiceClient.endpoint_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = EndpointServiceClient.parse_endpoint_path(path)
+    assert expected == actual
+
+
+def test_model_path():
+    project = "cuttlefish"
+    location = "mussel"
+    model = "winkle"
+
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
+    actual = EndpointServiceClient.model_path(project, location, model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "project": "nautilus",
+        "location": "scallop",
+        "model": "abalone",
+    }
+    path = EndpointServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
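+    # (each parse_* helper inverts the corresponding *_path template, so
+    # round-tripping the kwargs through both should be lossless).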
+ actual = EndpointServiceClient.parse_model_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "squid" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = EndpointServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + } + path = EndpointServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = EndpointServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "whelk" + + expected = "folders/{folder}".format(folder=folder,) + actual = EndpointServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + } + path = EndpointServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = EndpointServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "oyster" + + expected = "organizations/{organization}".format(organization=organization,) + actual = EndpointServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + } + path = EndpointServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = EndpointServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project,) + actual = EndpointServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + } + path = EndpointServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = EndpointServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = EndpointServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + } + path = EndpointServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
+ actual = EndpointServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.EndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = EndpointServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.EndpointServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = EndpointServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_job_service.py b/tests/unit/gapic/aiplatform_v1/test_job_service.py new file mode 100644 index 0000000000..a471b22658 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1/test_job_service.py @@ -0,0 +1,6161 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1.services.job_service import JobServiceAsyncClient +from google.cloud.aiplatform_v1.services.job_service import JobServiceClient +from google.cloud.aiplatform_v1.services.job_service import pagers +from google.cloud.aiplatform_v1.services.job_service import transports +from google.cloud.aiplatform_v1.types import accelerator_type +from google.cloud.aiplatform_v1.types import batch_prediction_job +from google.cloud.aiplatform_v1.types import ( + batch_prediction_job as gca_batch_prediction_job, +) +from google.cloud.aiplatform_v1.types import completion_stats +from google.cloud.aiplatform_v1.types import custom_job +from google.cloud.aiplatform_v1.types import custom_job as gca_custom_job +from google.cloud.aiplatform_v1.types import data_labeling_job +from google.cloud.aiplatform_v1.types import data_labeling_job as gca_data_labeling_job +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import env_var +from google.cloud.aiplatform_v1.types import hyperparameter_tuning_job +from google.cloud.aiplatform_v1.types import ( + hyperparameter_tuning_job as gca_hyperparameter_tuning_job, +) +from google.cloud.aiplatform_v1.types import io +from 
google.cloud.aiplatform_v1.types import job_service +from google.cloud.aiplatform_v1.types import job_state +from google.cloud.aiplatform_v1.types import machine_resources +from google.cloud.aiplatform_v1.types import manual_batch_tuning_parameters +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import study +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import any_pb2 as gp_any # type: ignore +from google.protobuf import duration_pb2 as duration # type: ignore +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore +from google.type import money_pb2 as money # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert JobServiceClient._get_default_mtls_endpoint(None) is None + assert ( + JobServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + JobServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +def test_job_service_client_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = JobServiceClient.from_service_account_info(info) + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient,]) +def test_job_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_job_service_client_get_transport_class(): + transport = JobServiceClient.get_transport_class() + available_transports = [ + transports.JobServiceGrpcTransport, + ] + assert transport in available_transports + + 
transport = JobServiceClient.get_transport_class("grpc") + assert transport == transports.JobServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) +def test_job_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(JobServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
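+    # Only "true" and "false" are accepted for that variable; any other value
+    # should surface as a ValueError at client construction time.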
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "true"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc", "false"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + JobServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(JobServiceClient) +) +@mock.patch.object( + JobServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(JobServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_job_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
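+    # Here the certificate comes from the ADC default client cert source rather
+    # than from an explicit client_options.client_cert_source callback.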
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_job_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (JobServiceClient, transports.JobServiceGrpcTransport, "grpc"), + ( + JobServiceAsyncClient, + transports.JobServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_job_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
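+    # The path is forwarded verbatim to the transport; no file I/O occurs
+    # because the transport constructor itself is mocked out.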
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_job_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1.services.job_service.transports.JobServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = JobServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_custom_job( + transport: str = "grpc", request_type=job_service.CreateCustomJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateCustomJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_custom_job.CustomJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_custom_job_from_dict(): + test_create_custom_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.CreateCustomJobRequest +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + + response = await client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. 
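+        # (async stubs record calls the same way, but these tests only assert
+        # that mock_calls is non-empty rather than an exact count).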
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateCustomJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_custom_job.CustomJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_create_custom_job_async_from_dict(): + await test_create_custom_job_async(request_type=dict) + + +def test_create_custom_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateCustomJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + call.return_value = gca_custom_job.CustomJob() + + client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_custom_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateCustomJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_custom_job.CustomJob() + ) + + await client.create_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_custom_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_custom_job.CustomJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_custom_job( + parent="parent_value", + custom_job=gca_custom_job.CustomJob(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
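+        # The flattened kwargs should have been packed into a single request
+        # proto; roughly equivalent to the explicit form:
+        #
+        #     client.create_custom_job(
+        #         request=job_service.CreateCustomJobRequest(
+        #             parent="parent_value",
+        #             custom_job=gca_custom_job.CustomJob(name="name_value"),
+        #         )
+        #     )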
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+        assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value")
+
+
+def test_create_custom_job_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.create_custom_job(
+            job_service.CreateCustomJobRequest(),
+            parent="parent_value",
+            custom_job=gca_custom_job.CustomJob(name="name_value"),
+        )
+
+
+@pytest.mark.asyncio
+async def test_create_custom_job_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.create_custom_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_custom_job.CustomJob()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.create_custom_job(
+            parent="parent_value",
+            custom_job=gca_custom_job.CustomJob(name="name_value"),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+        assert args[0].custom_job == gca_custom_job.CustomJob(name="name_value")
+
+
+@pytest.mark.asyncio
+async def test_create_custom_job_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.create_custom_job(
+            job_service.CreateCustomJobRequest(),
+            parent="parent_value",
+            custom_job=gca_custom_job.CustomJob(name="name_value"),
+        )
+
+
+def test_get_custom_job(
+    transport: str = "grpc", request_type=job_service.GetCustomJobRequest
+):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = custom_job.CustomJob(
+            name="name_value",
+            display_name="display_name_value",
+            state=job_state.JobState.JOB_STATE_QUEUED,
+        )
+
+        response = client.get_custom_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == job_service.GetCustomJobRequest()
+
+    # Establish that the response is the type that we expect.
+ + assert isinstance(response, custom_job.CustomJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_custom_job_from_dict(): + test_get_custom_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.GetCustomJobRequest +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob( + name="name_value", + display_name="display_name_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + + response = await client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetCustomJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, custom_job.CustomJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_get_custom_job_async_from_dict(): + await test_get_custom_job_async(request_type=dict) + + +def test_get_custom_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetCustomJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + call.return_value = custom_job.CustomJob() + + client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_custom_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetCustomJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + custom_job.CustomJob() + ) + + await client.get_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_get_custom_job_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = custom_job.CustomJob()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.get_custom_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_get_custom_job_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_custom_job(
+            job_service.GetCustomJobRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_custom_job_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.get_custom_job), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            custom_job.CustomJob()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_custom_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_get_custom_job_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_custom_job(
+            job_service.GetCustomJobRequest(), name="name_value",
+        )
+
+
+def test_list_custom_jobs(
+    transport: str = "grpc", request_type=job_service.ListCustomJobsRequest
+):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListCustomJobsResponse(
+            next_page_token="next_page_token_value",
+        )
+
+        response = client.list_custom_jobs(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == job_service.ListCustomJobsRequest()
+
+    # Establish that the response is the type that we expect.
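+    # The client wraps the raw ListCustomJobsResponse in a pager; iterating
+    # it would lazily fetch further pages via next_page_token.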
+ + assert isinstance(response, pagers.ListCustomJobsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_custom_jobs_from_dict(): + test_list_custom_jobs(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_custom_jobs_async( + transport: str = "grpc_asyncio", request_type=job_service.ListCustomJobsRequest +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse(next_page_token="next_page_token_value",) + ) + + response = await client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListCustomJobsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListCustomJobsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_custom_jobs_async_from_dict(): + await test_list_custom_jobs_async(request_type=dict) + + +def test_list_custom_jobs_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListCustomJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + call.return_value = job_service.ListCustomJobsResponse() + + client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_custom_jobs_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListCustomJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListCustomJobsResponse() + ) + + await client.list_custom_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_custom_jobs_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListCustomJobsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_custom_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_custom_jobs_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_custom_jobs(
+            job_service.ListCustomJobsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_custom_jobs_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            job_service.ListCustomJobsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_custom_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_custom_jobs_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_custom_jobs(
+            job_service.ListCustomJobsRequest(), parent="parent_value",
+        )
+
+
+def test_list_custom_jobs_pager():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
+        # Set the response to a series of pages.
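+        # The trailing RuntimeError acts as a sentinel: if the pager were to
+        # request more pages than the fixtures provide, the mock raises
+        # instead of silently repeating data.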
+        call.side_effect = (
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_custom_jobs(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, custom_job.CustomJob) for i in results)
+
+
+def test_list_custom_jobs_pages():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_custom_jobs), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_custom_jobs(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_custom_jobs_async_pager():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                    custom_job.CustomJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi",
+            ),
+            job_service.ListCustomJobsResponse(
+                custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_custom_jobs(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, custom_job.CustomJob) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_custom_jobs_async_pages():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_custom_jobs), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
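+        # Same page fixtures as the sync variant; here the AsyncMock is
+        # driven through the pager's async `pages` iterator below.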
+ call.side_effect = ( + job_service.ListCustomJobsResponse( + custom_jobs=[ + custom_job.CustomJob(), + custom_job.CustomJob(), + custom_job.CustomJob(), + ], + next_page_token="abc", + ), + job_service.ListCustomJobsResponse(custom_jobs=[], next_page_token="def",), + job_service.ListCustomJobsResponse( + custom_jobs=[custom_job.CustomJob(),], next_page_token="ghi", + ), + job_service.ListCustomJobsResponse( + custom_jobs=[custom_job.CustomJob(), custom_job.CustomJob(),], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_custom_jobs(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_custom_job( + transport: str = "grpc", request_type=job_service.DeleteCustomJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteCustomJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_custom_job_from_dict(): + test_delete_custom_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.DeleteCustomJobRequest +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteCustomJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_custom_job_async_from_dict(): + await test_delete_custom_job_async(request_type=dict) + + +def test_delete_custom_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteCustomJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
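+    # delete_custom_job is a long-running operation: the stub returns a raw
+    # operations_pb2.Operation, which the client wraps in an operation future.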
+    with mock.patch.object(
+        type(client.transport.delete_custom_job), "__call__"
+    ) as call:
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        client.delete_custom_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_delete_custom_job_field_headers_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = job_service.DeleteCustomJobRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.delete_custom_job), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/op")
+        )
+
+        await client.delete_custom_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_delete_custom_job_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.delete_custom_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.delete_custom_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_delete_custom_job_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.delete_custom_job(
+            job_service.DeleteCustomJobRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_custom_job_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.delete_custom_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_custom_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_custom_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_custom_job( + job_service.DeleteCustomJobRequest(), name="name_value", + ) + + +def test_cancel_custom_job( + transport: str = "grpc", request_type=job_service.CancelCustomJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelCustomJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_custom_job_from_dict(): + test_cancel_custom_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_custom_job_async( + transport: str = "grpc_asyncio", request_type=job_service.CancelCustomJobRequest +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelCustomJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_custom_job_async_from_dict(): + await test_cancel_custom_job_async(request_type=dict) + + +def test_cancel_custom_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelCustomJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_custom_job), "__call__" + ) as call: + call.return_value = None + + client.cancel_custom_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
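+    # The x-goog-request-params entry lets the backend route the RPC by
+    # resource name without inspecting the request payload.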
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_cancel_custom_job_field_headers_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = job_service.CancelCustomJobRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.cancel_custom_job), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+
+        await client.cancel_custom_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_cancel_custom_job_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.cancel_custom_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = None
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.cancel_custom_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_cancel_custom_job_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.cancel_custom_job(
+            job_service.CancelCustomJobRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_cancel_custom_job_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.cancel_custom_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.cancel_custom_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_cancel_custom_job_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+ with pytest.raises(ValueError): + await client.cancel_custom_job( + job_service.CancelCustomJobRequest(), name="name_value", + ) + + +def test_create_data_labeling_job( + transport: str = "grpc", request_type=job_service.CreateDataLabelingJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + + response = client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_data_labeling_job.DataLabelingJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.datasets == ["datasets_value"] + + assert response.labeler_count == 1375 + + assert response.instruction_uri == "instruction_uri_value" + + assert response.inputs_schema_uri == "inputs_schema_uri_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.labeling_progress == 1810 + + assert response.specialist_pools == ["specialist_pools_value"] + + +def test_create_data_labeling_job_from_dict(): + test_create_data_labeling_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateDataLabelingJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + ) + + response = await client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_data_labeling_job.DataLabelingJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.datasets == ["datasets_value"] + + assert response.labeler_count == 1375 + + assert response.instruction_uri == "instruction_uri_value" + + assert response.inputs_schema_uri == "inputs_schema_uri_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.labeling_progress == 1810 + + assert response.specialist_pools == ["specialist_pools_value"] + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_async_from_dict(): + await test_create_data_labeling_job_async(request_type=dict) + + +def test_create_data_labeling_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateDataLabelingJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + call.return_value = gca_data_labeling_job.DataLabelingJob() + + client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_data_labeling_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateDataLabelingJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_data_labeling_job.DataLabelingJob() + ) + + await client.create_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_data_labeling_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_data_labeling_job.DataLabelingJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
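+        # Flattened fields and an explicit request object are mutually
+        # exclusive; mixing them raises ValueError (see the
+        # *_flattened_error tests below).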
+        client.create_data_labeling_job(
+            parent="parent_value",
+            data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+        assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(
+            name="name_value"
+        )
+
+
+def test_create_data_labeling_job_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.create_data_labeling_job(
+            job_service.CreateDataLabelingJobRequest(),
+            parent="parent_value",
+            data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"),
+        )
+
+
+@pytest.mark.asyncio
+async def test_create_data_labeling_job_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.create_data_labeling_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            gca_data_labeling_job.DataLabelingJob()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.create_data_labeling_job(
+            parent="parent_value",
+            data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+        assert args[0].data_labeling_job == gca_data_labeling_job.DataLabelingJob(
+            name="name_value"
+        )
+
+
+@pytest.mark.asyncio
+async def test_create_data_labeling_job_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.create_data_labeling_job(
+            job_service.CreateDataLabelingJobRequest(),
+            parent="parent_value",
+            data_labeling_job=gca_data_labeling_job.DataLabelingJob(name="name_value"),
+        )
+
+
+def test_get_data_labeling_job(
+    transport: str = "grpc", request_type=job_service.GetDataLabelingJobRequest
+):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_data_labeling_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+ call.return_value = data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + + response = client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, data_labeling_job.DataLabelingJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.datasets == ["datasets_value"] + + assert response.labeler_count == 1375 + + assert response.instruction_uri == "instruction_uri_value" + + assert response.inputs_schema_uri == "inputs_schema_uri_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.labeling_progress == 1810 + + assert response.specialist_pools == ["specialist_pools_value"] + + +def test_get_data_labeling_job_from_dict(): + test_get_data_labeling_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_async( + transport: str = "grpc_asyncio", request_type=job_service.GetDataLabelingJobRequest +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob( + name="name_value", + display_name="display_name_value", + datasets=["datasets_value"], + labeler_count=1375, + instruction_uri="instruction_uri_value", + inputs_schema_uri="inputs_schema_uri_value", + state=job_state.JobState.JOB_STATE_QUEUED, + labeling_progress=1810, + specialist_pools=["specialist_pools_value"], + ) + ) + + response = await client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetDataLabelingJobRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, data_labeling_job.DataLabelingJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.datasets == ["datasets_value"] + + assert response.labeler_count == 1375 + + assert response.instruction_uri == "instruction_uri_value" + + assert response.inputs_schema_uri == "inputs_schema_uri_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + assert response.labeling_progress == 1810 + + assert response.specialist_pools == ["specialist_pools_value"] + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_async_from_dict(): + await test_get_data_labeling_job_async(request_type=dict) + + +def test_get_data_labeling_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetDataLabelingJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + call.return_value = data_labeling_job.DataLabelingJob() + + client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_data_labeling_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetDataLabelingJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + data_labeling_job.DataLabelingJob() + ) + + await client.get_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_data_labeling_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = data_labeling_job.DataLabelingJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_data_labeling_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_get_data_labeling_job_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_data_labeling_job(
+            job_service.GetDataLabelingJobRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_data_labeling_job_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_data_labeling_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            data_labeling_job.DataLabelingJob()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_data_labeling_job(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_get_data_labeling_job_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_data_labeling_job(
+            job_service.GetDataLabelingJobRequest(), name="name_value",
+        )
+
+
+def test_list_data_labeling_jobs(
+    transport: str = "grpc", request_type=job_service.ListDataLabelingJobsRequest
+):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_data_labeling_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListDataLabelingJobsResponse(
+            next_page_token="next_page_token_value",
+        )
+
+        response = client.list_data_labeling_jobs(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == job_service.ListDataLabelingJobsRequest()
+
+    # Establish that the response is the type that we expect.
+
+    assert isinstance(response, pagers.ListDataLabelingJobsPager)
+
+    assert response.next_page_token == "next_page_token_value"
+
+
+def test_list_data_labeling_jobs_from_dict():
+    test_list_data_labeling_jobs(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_async(
+    transport: str = "grpc_asyncio",
+    request_type=job_service.ListDataLabelingJobsRequest,
+):
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+ request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListDataLabelingJobsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDataLabelingJobsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_data_labeling_jobs_async_from_dict(): + await test_list_data_labeling_jobs_async(request_type=dict) + + +def test_list_data_labeling_jobs_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListDataLabelingJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + call.return_value = job_service.ListDataLabelingJobsResponse() + + client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_data_labeling_jobs_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListDataLabelingJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListDataLabelingJobsResponse() + ) + + await client.list_data_labeling_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_data_labeling_jobs_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_data_labeling_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListDataLabelingJobsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
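+        # Only `parent` is flattened for this RPC; paging fields such as
+        # page_size travel in the request object instead.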
+ client.list_data_labeling_jobs(parent="parent_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].parent == "parent_value"
+
+
+def test_list_data_labeling_jobs_flattened_error():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.list_data_labeling_jobs(
+ job_service.ListDataLabelingJobsRequest(), parent="parent_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_flattened_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ job_service.ListDataLabelingJobsResponse()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.list_data_labeling_jobs(parent="parent_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_flattened_error_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.list_data_labeling_jobs(
+ job_service.ListDataLabelingJobsRequest(), parent="parent_value",
+ )
+
+
+def test_list_data_labeling_jobs_pager():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs), "__call__"
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ next_page_token="abc",
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[], next_page_token="def",
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
+ next_page_token="ghi",
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+
+ metadata = ()
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+ )
+ pager = client.list_data_labeling_jobs(request={})
+
+ assert pager._metadata == metadata
+
+ results = [i for i in pager]
+ assert len(results) == 6
+ assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in results)
+
+
+def test_list_data_labeling_jobs_pages():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
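+ # Patching __call__ on the transport method's type intercepts the RPC at
+ # the stub boundary, so no channel or network traffic is involved.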
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs), "__call__"
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ next_page_token="abc",
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[], next_page_token="def",
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
+ next_page_token="ghi",
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = list(client.list_data_labeling_jobs(request={}).pages)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_async_pager():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs),
+ "__call__",
+ new_callable=mock.AsyncMock,
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ next_page_token="abc",
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[], next_page_token="def",
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[data_labeling_job.DataLabelingJob(),],
+ next_page_token="ghi",
+ ),
+ job_service.ListDataLabelingJobsResponse(
+ data_labeling_jobs=[
+ data_labeling_job.DataLabelingJob(),
+ data_labeling_job.DataLabelingJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+ async_pager = await client.list_data_labeling_jobs(request={},)
+ assert async_pager.next_page_token == "abc"
+ responses = []
+ async for response in async_pager:
+ responses.append(response)
+
+ assert len(responses) == 6
+ assert all(isinstance(i, data_labeling_job.DataLabelingJob) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_data_labeling_jobs_async_pages():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_data_labeling_jobs),
+ "__call__",
+ new_callable=mock.AsyncMock,
+ ) as call:
+ # Set the response to a series of pages.
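+ # side_effect yields one response per invocation; the final page carries
+ # an empty next_page_token, so the trailing RuntimeError only fires if the
+ # pager were to over-fetch.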
+ call.side_effect = ( + job_service.ListDataLabelingJobsResponse( + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + data_labeling_job.DataLabelingJob(), + data_labeling_job.DataLabelingJob(), + ], + next_page_token="abc", + ), + job_service.ListDataLabelingJobsResponse( + data_labeling_jobs=[], next_page_token="def", + ), + job_service.ListDataLabelingJobsResponse( + data_labeling_jobs=[data_labeling_job.DataLabelingJob(),], + next_page_token="ghi", + ), + job_service.ListDataLabelingJobsResponse( + data_labeling_jobs=[ + data_labeling_job.DataLabelingJob(), + data_labeling_job.DataLabelingJob(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_data_labeling_jobs(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_data_labeling_job( + transport: str = "grpc", request_type=job_service.DeleteDataLabelingJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_data_labeling_job_from_dict(): + test_delete_data_labeling_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_data_labeling_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteDataLabelingJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteDataLabelingJobRequest() + + # Establish that the response is the type that we expect. 
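+ # Delete is a long-running operation: the raw operations_pb2.Operation is
+ # wrapped in an operation future rather than returned directly.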
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_data_labeling_job_async_from_dict(): + await test_delete_data_labeling_job_async(request_type=dict) + + +def test_delete_data_labeling_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteDataLabelingJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_data_labeling_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteDataLabelingJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_data_labeling_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_data_labeling_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_data_labeling_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
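+ # The ValueError is raised client-side before any RPC is attempted, which
+ # is why this test needs no transport mock.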
+ with pytest.raises(ValueError):
+ client.delete_data_labeling_job(
+ job_service.DeleteDataLabelingJobRequest(), name="name_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_delete_data_labeling_job_flattened_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.delete_data_labeling_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.delete_data_labeling_job(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_delete_data_labeling_job_flattened_error_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.delete_data_labeling_job(
+ job_service.DeleteDataLabelingJobRequest(), name="name_value",
+ )
+
+
+def test_cancel_data_labeling_job(
+ transport: str = "grpc", request_type=job_service.CancelDataLabelingJobRequest
+):
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.cancel_data_labeling_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = None
+
+ response = client.cancel_data_labeling_job(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == job_service.CancelDataLabelingJobRequest()
+
+ # Establish that the response is the type that we expect.
+ assert response is None
+
+
+def test_cancel_data_labeling_job_from_dict():
+ test_cancel_data_labeling_job(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_cancel_data_labeling_job_async(
+ transport: str = "grpc_asyncio",
+ request_type=job_service.CancelDataLabelingJobRequest,
+):
+ client = JobServiceAsyncClient(
+ credentials=credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.cancel_data_labeling_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+
+ response = await client.cancel_data_labeling_job(request)
+
+ # Establish that the underlying gRPC stub method was called.
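+ # The async tests only assert a truthy call count; the first recorded
+ # call still carries the request for the equality check below.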
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelDataLabelingJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_data_labeling_job_async_from_dict(): + await test_cancel_data_labeling_job_async(request_type=dict) + + +def test_cancel_data_labeling_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelDataLabelingJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: + call.return_value = None + + client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_data_labeling_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelDataLabelingJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_data_labeling_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_cancel_data_labeling_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_data_labeling_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.cancel_data_labeling_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_cancel_data_labeling_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError):
+ client.cancel_data_labeling_job(
+ job_service.CancelDataLabelingJobRequest(), name="name_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_cancel_data_labeling_job_flattened_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.cancel_data_labeling_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.cancel_data_labeling_job(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_cancel_data_labeling_job_flattened_error_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.cancel_data_labeling_job(
+ job_service.CancelDataLabelingJobRequest(), name="name_value",
+ )
+
+
+def test_create_hyperparameter_tuning_job(
+ transport: str = "grpc",
+ request_type=job_service.CreateHyperparameterTuningJobRequest,
+):
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.create_hyperparameter_tuning_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob(
+ name="name_value",
+ display_name="display_name_value",
+ max_trial_count=1609,
+ parallel_trial_count=2128,
+ max_failed_trial_count=2317,
+ state=job_state.JobState.JOB_STATE_QUEUED,
+ )
+
+ response = client.create_hyperparameter_tuning_job(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == job_service.CreateHyperparameterTuningJobRequest()
+
+ # Establish that the response is the type that we expect.
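+ # The numeric values (1609, 2128, 2317) appear to be arbitrary sentinels;
+ # the asserts confirm each populated field round-trips intact.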
+ + assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.max_trial_count == 1609 + + assert response.parallel_trial_count == 2128 + + assert response.max_failed_trial_count == 2317 + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_hyperparameter_tuning_job_from_dict(): + test_create_hyperparameter_tuning_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateHyperparameterTuningJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + + response = await client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_hyperparameter_tuning_job.HyperparameterTuningJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.max_trial_count == 1609 + + assert response.parallel_trial_count == 2128 + + assert response.max_failed_trial_count == 2317 + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_async_from_dict(): + await test_create_hyperparameter_tuning_job_async(request_type=dict) + + +def test_create_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateHyperparameterTuningJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob() + + client.create_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
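+ # Routing headers travel in the call's metadata kwarg as an
+ # ("x-goog-request-params", ...) pair, unpacked and inspected below.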
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_create_hyperparameter_tuning_job_field_headers_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Any value that is part of the HTTP/1.1 URI should be sent as
+ # a field header. Set these to a non-empty value.
+ request = job_service.CreateHyperparameterTuningJobRequest()
+ request.parent = "parent/value"
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.create_hyperparameter_tuning_job), "__call__"
+ ) as call:
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ gca_hyperparameter_tuning_job.HyperparameterTuningJob()
+ )
+
+ await client.create_hyperparameter_tuning_job(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_create_hyperparameter_tuning_job_flattened():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.create_hyperparameter_tuning_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = gca_hyperparameter_tuning_job.HyperparameterTuningJob()
+
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.create_hyperparameter_tuning_job(
+ parent="parent_value",
+ hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(
+ name="name_value"
+ ),
+ )
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].parent == "parent_value"
+
+ assert args[
+ 0
+ ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob(
+ name="name_value"
+ )
+
+
+def test_create_hyperparameter_tuning_job_flattened_error():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.create_hyperparameter_tuning_job(
+ job_service.CreateHyperparameterTuningJobRequest(),
+ parent="parent_value",
+ hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob(
+ name="name_value"
+ ),
+ )
+
+
+@pytest.mark.asyncio
+async def test_create_hyperparameter_tuning_job_flattened_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.create_hyperparameter_tuning_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ gca_hyperparameter_tuning_job.HyperparameterTuningJob()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.create_hyperparameter_tuning_job( + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[ + 0 + ].hyperparameter_tuning_job == gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ) + + +@pytest.mark.asyncio +async def test_create_hyperparameter_tuning_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_hyperparameter_tuning_job( + job_service.CreateHyperparameterTuningJobRequest(), + parent="parent_value", + hyperparameter_tuning_job=gca_hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value" + ), + ) + + +def test_get_hyperparameter_tuning_job( + transport: str = "grpc", request_type=job_service.GetHyperparameterTuningJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.max_trial_count == 1609 + + assert response.parallel_trial_count == 2128 + + assert response.max_failed_trial_count == 2317 + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_hyperparameter_tuning_job_from_dict(): + test_get_hyperparameter_tuning_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.GetHyperparameterTuningJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
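+ # FakeUnaryUnaryCall wraps the response so the mocked stub can be awaited
+ # like a real grpc.aio unary-unary call.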
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob( + name="name_value", + display_name="display_name_value", + max_trial_count=1609, + parallel_trial_count=2128, + max_failed_trial_count=2317, + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + + response = await client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, hyperparameter_tuning_job.HyperparameterTuningJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.max_trial_count == 1609 + + assert response.parallel_trial_count == 2128 + + assert response.max_failed_trial_count == 2317 + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_get_hyperparameter_tuning_job_async_from_dict(): + await test_get_hyperparameter_tuning_job_async(request_type=dict) + + +def test_get_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetHyperparameterTuningJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob() + + client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_hyperparameter_tuning_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetHyperparameterTuningJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + hyperparameter_tuning_job.HyperparameterTuningJob() + ) + + await client.get_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_hyperparameter_tuning_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(
+ type(client.transport.get_hyperparameter_tuning_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = hyperparameter_tuning_job.HyperparameterTuningJob()
+
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.get_hyperparameter_tuning_job(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].name == "name_value"
+
+
+def test_get_hyperparameter_tuning_job_flattened_error():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.get_hyperparameter_tuning_job(
+ job_service.GetHyperparameterTuningJobRequest(), name="name_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_get_hyperparameter_tuning_job_flattened_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.get_hyperparameter_tuning_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ hyperparameter_tuning_job.HyperparameterTuningJob()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.get_hyperparameter_tuning_job(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_get_hyperparameter_tuning_job_flattened_error_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.get_hyperparameter_tuning_job(
+ job_service.GetHyperparameterTuningJobRequest(), name="name_value",
+ )
+
+
+def test_list_hyperparameter_tuning_jobs(
+ transport: str = "grpc",
+ request_type=job_service.ListHyperparameterTuningJobsRequest,
+):
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = job_service.ListHyperparameterTuningJobsResponse(
+ next_page_token="next_page_token_value",
+ )
+
+ response = client.list_hyperparameter_tuning_jobs(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0] == job_service.ListHyperparameterTuningJobsRequest()
+
+ # Establish that the response is the type that we expect.
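+ # List RPCs return a Pager that fetches pages lazily while exposing fields
+ # such as next_page_token from the current response.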
+ + assert isinstance(response, pagers.ListHyperparameterTuningJobsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_hyperparameter_tuning_jobs_from_dict(): + test_list_hyperparameter_tuning_jobs(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_hyperparameter_tuning_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListHyperparameterTuningJobsRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListHyperparameterTuningJobsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListHyperparameterTuningJobsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_hyperparameter_tuning_jobs_async_from_dict(): + await test_list_hyperparameter_tuning_jobs_async(request_type=dict) + + +def test_list_hyperparameter_tuning_jobs_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListHyperparameterTuningJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + call.return_value = job_service.ListHyperparameterTuningJobsResponse() + + client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_hyperparameter_tuning_jobs_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListHyperparameterTuningJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_hyperparameter_tuning_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListHyperparameterTuningJobsResponse() + ) + + await client.list_hyperparameter_tuning_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_hyperparameter_tuning_jobs_flattened():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = job_service.ListHyperparameterTuningJobsResponse()
+
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.list_hyperparameter_tuning_jobs(parent="parent_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].parent == "parent_value"
+
+
+def test_list_hyperparameter_tuning_jobs_flattened_error():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.list_hyperparameter_tuning_jobs(
+ job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_list_hyperparameter_tuning_jobs_flattened_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ job_service.ListHyperparameterTuningJobsResponse()
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.list_hyperparameter_tuning_jobs(parent="parent_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_hyperparameter_tuning_jobs_flattened_error_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.list_hyperparameter_tuning_jobs(
+ job_service.ListHyperparameterTuningJobsRequest(), parent="parent_value",
+ )
+
+
+def test_list_hyperparameter_tuning_jobs_pager():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs), "__call__"
+ ) as call:
+ # Set the response to a series of pages.
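+ # The expected metadata built below mirrors the routing header the client
+ # injects, keyed on the (here empty) parent field.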
+ call.side_effect = (
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token="abc",
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[], next_page_token="def",
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token="ghi",
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+
+ metadata = ()
+ metadata = tuple(metadata) + (
+ gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+ )
+ pager = client.list_hyperparameter_tuning_jobs(request={})
+
+ assert pager._metadata == metadata
+
+ results = [i for i in pager]
+ assert len(results) == 6
+ assert all(
+ isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob)
+ for i in results
+ )
+
+
+def test_list_hyperparameter_tuning_jobs_pages():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs), "__call__"
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token="abc",
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[], next_page_token="def",
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token="ghi",
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = list(client.list_hyperparameter_tuning_jobs(request={}).pages)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_hyperparameter_tuning_jobs_async_pager():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs),
+ "__call__",
+ new_callable=mock.AsyncMock,
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token="abc",
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[], next_page_token="def",
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token="ghi",
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+ async_pager = await client.list_hyperparameter_tuning_jobs(request={},)
+ assert async_pager.next_page_token == "abc"
+ responses = []
+ async for response in async_pager:
+ responses.append(response)
+
+ assert len(responses) == 6
+ assert all(
+ isinstance(i, hyperparameter_tuning_job.HyperparameterTuningJob)
+ for i in responses
+ )
+
+
+@pytest.mark.asyncio
+async def test_list_hyperparameter_tuning_jobs_async_pages():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.list_hyperparameter_tuning_jobs),
+ "__call__",
+ new_callable=mock.AsyncMock,
+ ) as call:
+ # Set the response to a series of pages.
+ call.side_effect = (
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token="abc",
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[], next_page_token="def",
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ next_page_token="ghi",
+ ),
+ job_service.ListHyperparameterTuningJobsResponse(
+ hyperparameter_tuning_jobs=[
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ hyperparameter_tuning_job.HyperparameterTuningJob(),
+ ],
+ ),
+ RuntimeError,
+ )
+ pages = []
+ async for page_ in (
+ await client.list_hyperparameter_tuning_jobs(request={})
+ ).pages:
+ pages.append(page_)
+ for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+ assert page_.raw_page.next_page_token == token
+
+
+def test_delete_hyperparameter_tuning_job(
+ transport: str = "grpc",
+ request_type=job_service.DeleteHyperparameterTuningJobRequest,
+):
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.delete_hyperparameter_tuning_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/spam")
+
+ response = client.delete_hyperparameter_tuning_job(request)
+
+ # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_hyperparameter_tuning_job_from_dict(): + test_delete_hyperparameter_tuning_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteHyperparameterTuningJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_async_from_dict(): + await test_delete_hyperparameter_tuning_job_async(request_type=dict) + + +def test_delete_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteHyperparameterTuningJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_hyperparameter_tuning_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteHyperparameterTuningJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+ assert args[0] == request
+
+ # Establish that the field header was sent.
+ _, _, kw = call.mock_calls[0]
+ assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_delete_hyperparameter_tuning_job_flattened():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.delete_hyperparameter_tuning_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = operations_pb2.Operation(name="operations/op")
+
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ client.delete_hyperparameter_tuning_job(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls) == 1
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].name == "name_value"
+
+
+def test_delete_hyperparameter_tuning_job_flattened_error():
+ client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ client.delete_hyperparameter_tuning_job(
+ job_service.DeleteHyperparameterTuningJobRequest(), name="name_value",
+ )
+
+
+@pytest.mark.asyncio
+async def test_delete_hyperparameter_tuning_job_flattened_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.delete_hyperparameter_tuning_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+ operations_pb2.Operation(name="operations/spam")
+ )
+ # Call the method with a truthy value for each flattened field,
+ # using the keyword arguments to the method.
+ response = await client.delete_hyperparameter_tuning_job(name="name_value",)
+
+ # Establish that the underlying call was made with the expected
+ # request object values.
+ assert len(call.mock_calls)
+ _, args, _ = call.mock_calls[0]
+
+ assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_delete_hyperparameter_tuning_job_flattened_error_async():
+ client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+ # Attempting to call a method with both a request object and flattened
+ # fields is an error.
+ with pytest.raises(ValueError):
+ await client.delete_hyperparameter_tuning_job(
+ job_service.DeleteHyperparameterTuningJobRequest(), name="name_value",
+ )
+
+
+def test_cancel_hyperparameter_tuning_job(
+ transport: str = "grpc",
+ request_type=job_service.CancelHyperparameterTuningJobRequest,
+):
+ client = JobServiceClient(
+ credentials=credentials.AnonymousCredentials(), transport=transport,
+ )
+
+ # Everything is optional in proto3 as far as the runtime is concerned,
+ # and we are mocking out the actual API, so just send an empty request.
+ request = request_type()
+
+ # Mock the actual call within the gRPC stub, and fake the request.
+ with mock.patch.object(
+ type(client.transport.cancel_hyperparameter_tuning_job), "__call__"
+ ) as call:
+ # Designate an appropriate return value for the call.
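+ # Cancel returns google.protobuf.Empty on the wire; the client surfaces
+ # it as None.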
+ call.return_value = None + + response = client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_hyperparameter_tuning_job_from_dict(): + test_cancel_hyperparameter_tuning_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelHyperparameterTuningJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelHyperparameterTuningJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_async_from_dict(): + await test_cancel_hyperparameter_tuning_job_async(request_type=dict) + + +def test_cancel_hyperparameter_tuning_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelHyperparameterTuningJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = None + + client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelHyperparameterTuningJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_hyperparameter_tuning_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_cancel_hyperparameter_tuning_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.cancel_hyperparameter_tuning_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_cancel_hyperparameter_tuning_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.cancel_hyperparameter_tuning_job( + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_hyperparameter_tuning_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.cancel_hyperparameter_tuning_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_cancel_hyperparameter_tuning_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.cancel_hyperparameter_tuning_job( + job_service.CancelHyperparameterTuningJobRequest(), name="name_value", + ) + + +def test_create_batch_prediction_job( + transport: str = "grpc", request_type=job_service.CreateBatchPredictionJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
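+        # Only scalar fields are populated here; checking this subset is enough
+        # to show the response is plumbed through unchanged.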
+ call.return_value = gca_batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.model == "model_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_create_batch_prediction_job_from_dict(): + test_create_batch_prediction_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CreateBatchPredictionJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + + response = await client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CreateBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_batch_prediction_job.BatchPredictionJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.model == "model_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_create_batch_prediction_job_async_from_dict(): + await test_create_batch_prediction_job_async(request_type=dict) + + +def test_create_batch_prediction_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateBatchPredictionJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + call.return_value = gca_batch_prediction_job.BatchPredictionJob() + + client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CreateBatchPredictionJobRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob() + ) + + await client.create_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_batch_prediction_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_batch_prediction_job.BatchPredictionJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_batch_prediction_job( + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) + + +def test_create_batch_prediction_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_batch_prediction_job( + job_service.CreateBatchPredictionJobRequest(), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), + ) + + +@pytest.mark.asyncio +async def test_create_batch_prediction_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_batch_prediction_job.BatchPredictionJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_batch_prediction_job.BatchPredictionJob() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ response = await client.create_batch_prediction_job( + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[ + 0 + ].batch_prediction_job == gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ) + + +@pytest.mark.asyncio +async def test_create_batch_prediction_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_batch_prediction_job( + job_service.CreateBatchPredictionJobRequest(), + parent="parent_value", + batch_prediction_job=gca_batch_prediction_job.BatchPredictionJob( + name="name_value" + ), + ) + + +def test_get_batch_prediction_job( + transport: str = "grpc", request_type=job_service.GetBatchPredictionJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + + response = client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, batch_prediction_job.BatchPredictionJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.model == "model_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +def test_get_batch_prediction_job_from_dict(): + test_get_batch_prediction_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.GetBatchPredictionJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
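+        # FakeUnaryUnaryCall wraps the response so the mocked stub is awaitable,
+        # mimicking a real async gRPC invocation.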
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob( + name="name_value", + display_name="display_name_value", + model="model_value", + state=job_state.JobState.JOB_STATE_QUEUED, + ) + ) + + response = await client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.GetBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, batch_prediction_job.BatchPredictionJob) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.model == "model_value" + + assert response.state == job_state.JobState.JOB_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_get_batch_prediction_job_async_from_dict(): + await test_get_batch_prediction_job_async(request_type=dict) + + +def test_get_batch_prediction_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetBatchPredictionJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + call.return_value = batch_prediction_job.BatchPredictionJob() + + client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.GetBatchPredictionJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob() + ) + + await client.get_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_batch_prediction_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = batch_prediction_job.BatchPredictionJob() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ client.get_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_batch_prediction_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_batch_prediction_job( + job_service.GetBatchPredictionJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_batch_prediction_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = batch_prediction_job.BatchPredictionJob() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + batch_prediction_job.BatchPredictionJob() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_batch_prediction_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_batch_prediction_job( + job_service.GetBatchPredictionJobRequest(), name="name_value", + ) + + +def test_list_batch_prediction_jobs( + transport: str = "grpc", request_type=job_service.ListBatchPredictionJobsRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = job_service.ListBatchPredictionJobsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListBatchPredictionJobsRequest() + + # Establish that the response is the type that we expect. 
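+    # List methods return a pager that transparently fetches additional pages
+    # as it is iterated.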
+ + assert isinstance(response, pagers.ListBatchPredictionJobsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_batch_prediction_jobs_from_dict(): + test_list_batch_prediction_jobs(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_batch_prediction_jobs_async( + transport: str = "grpc_asyncio", + request_type=job_service.ListBatchPredictionJobsRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.ListBatchPredictionJobsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListBatchPredictionJobsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_batch_prediction_jobs_async_from_dict(): + await test_list_batch_prediction_jobs_async(request_type=dict) + + +def test_list_batch_prediction_jobs_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListBatchPredictionJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + call.return_value = job_service.ListBatchPredictionJobsResponse() + + client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_batch_prediction_jobs_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.ListBatchPredictionJobsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_batch_prediction_jobs), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + job_service.ListBatchPredictionJobsResponse() + ) + + await client.list_batch_prediction_jobs(request) + + # Establish that the underlying gRPC stub method was called. 
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_batch_prediction_jobs_flattened():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListBatchPredictionJobsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_batch_prediction_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_batch_prediction_jobs_flattened_error():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_batch_prediction_jobs(
+            job_service.ListBatchPredictionJobsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_flattened_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = job_service.ListBatchPredictionJobsResponse()
+
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            job_service.ListBatchPredictionJobsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_batch_prediction_jobs(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_flattened_error_async():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_batch_prediction_jobs(
+            job_service.ListBatchPredictionJobsRequest(), parent="parent_value",
+        )
+
+
+def test_list_batch_prediction_jobs_pager():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
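+        # Each call to the stub consumes the next item; the trailing RuntimeError
+        # fails the test if the pager requests more pages than were provided.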
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[], next_page_token="def",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_batch_prediction_jobs(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(
+            isinstance(i, batch_prediction_job.BatchPredictionJob) for i in results
+        )
+
+
+def test_list_batch_prediction_jobs_pages():
+    client = JobServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_batch_prediction_jobs), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[], next_page_token="def",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_batch_prediction_jobs(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_async_pager():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_batch_prediction_jobs),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
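+        # Four pages holding 3, 0, 1, and 2 jobs; the empty page must not stop
+        # iteration early, for a total of six results.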
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[], next_page_token="def",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_batch_prediction_jobs(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(
+            isinstance(i, batch_prediction_job.BatchPredictionJob) for i in responses
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_batch_prediction_jobs_async_pages():
+    client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_batch_prediction_jobs),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+                next_page_token="abc",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[], next_page_token="def",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[batch_prediction_job.BatchPredictionJob(),],
+                next_page_token="ghi",
+            ),
+            job_service.ListBatchPredictionJobsResponse(
+                batch_prediction_jobs=[
+                    batch_prediction_job.BatchPredictionJob(),
+                    batch_prediction_job.BatchPredictionJob(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_batch_prediction_jobs(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_delete_batch_prediction_job(
+    transport: str = "grpc", request_type=job_service.DeleteBatchPredictionJobRequest
+):
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.delete_batch_prediction_job), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.delete_batch_prediction_job(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == job_service.DeleteBatchPredictionJobRequest()
+
+    # Establish that the response is the type that we expect.
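+    # Delete is a long-running operation, so the client returns a future that
+    # resolves once the server-side operation completes.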
+ assert isinstance(response, future.Future) + + +def test_delete_batch_prediction_job_from_dict(): + test_delete_batch_prediction_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.DeleteBatchPredictionJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.DeleteBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_async_from_dict(): + await test_delete_batch_prediction_job_async(request_type=dict) + + +def test_delete_batch_prediction_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteBatchPredictionJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.DeleteBatchPredictionJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_batch_prediction_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_batch_prediction_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_batch_prediction_job( + job_service.DeleteBatchPredictionJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_batch_prediction_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_batch_prediction_job( + job_service.DeleteBatchPredictionJobRequest(), name="name_value", + ) + + +def test_cancel_batch_prediction_job( + transport: str = "grpc", request_type=job_service.CancelBatchPredictionJobRequest +): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + response = client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +def test_cancel_batch_prediction_job_from_dict(): + test_cancel_batch_prediction_job(request_type=dict) + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_async( + transport: str = "grpc_asyncio", + request_type=job_service.CancelBatchPredictionJobRequest, +): + client = JobServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + response = await client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == job_service.CancelBatchPredictionJobRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_async_from_dict(): + await test_cancel_batch_prediction_job_async(request_type=dict) + + +def test_cancel_batch_prediction_job_field_headers(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelBatchPredictionJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + call.return_value = None + + client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_field_headers_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = job_service.CancelBatchPredictionJobRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_batch_prediction_job(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_cancel_batch_prediction_job_flattened(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.cancel_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_cancel_batch_prediction_job_flattened_error(): + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.cancel_batch_prediction_job( + job_service.CancelBatchPredictionJobRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_flattened_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_batch_prediction_job), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.cancel_batch_prediction_job(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_cancel_batch_prediction_job_flattened_error_async(): + client = JobServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.cancel_batch_prediction_job( + job_service.CancelBatchPredictionJobRequest(), name="name_value", + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.JobServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.JobServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = JobServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. 
+ transport = transports.JobServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = JobServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.JobServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = JobServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.JobServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.JobServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport,], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = JobServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.JobServiceGrpcTransport,) + + +def test_job_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.JobServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_job_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.JobServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
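+    # Five verbs (create/get/list/delete/cancel) across the four job types.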
+ methods = ( + "create_custom_job", + "get_custom_job", + "list_custom_jobs", + "delete_custom_job", + "cancel_custom_job", + "create_data_labeling_job", + "get_data_labeling_job", + "list_data_labeling_jobs", + "delete_data_labeling_job", + "cancel_data_labeling_job", + "create_hyperparameter_tuning_job", + "get_hyperparameter_tuning_job", + "list_hyperparameter_tuning_jobs", + "delete_hyperparameter_tuning_job", + "cancel_hyperparameter_tuning_job", + "create_batch_prediction_job", + "get_batch_prediction_job", + "list_batch_prediction_jobs", + "delete_batch_prediction_job", + "cancel_batch_prediction_job", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_job_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.JobServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_job_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.job_service.transports.JobServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.JobServiceTransport() + adc.assert_called_once() + + +def test_job_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + JobServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_job_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.JobServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) +def test_job_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. 
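+    # Explicitly supplied channel credentials must be forwarded to
+    # create_channel unchanged.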
+    with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+        mock_ssl_channel_creds = mock.Mock()
+        transport_class(
+            host="squid.clam.whelk",
+            credentials=cred,
+            ssl_channel_credentials=mock_ssl_channel_creds,
+        )
+        mock_create_channel.assert_called_once_with(
+            "squid.clam.whelk:443",
+            credentials=cred,
+            credentials_file=None,
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            ssl_credentials=mock_ssl_channel_creds,
+            quota_project_id=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+    # Check that client_cert_source_for_mtls is used if ssl_channel_credentials
+    # is not provided.
+    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+            transport_class(
+                credentials=cred,
+                client_cert_source_for_mtls=client_cert_source_callback,
+            )
+            expected_cert, expected_key = client_cert_source_callback()
+            mock_ssl_cred.assert_called_once_with(
+                certificate_chain=expected_cert, private_key=expected_key
+            )
+
+
+def test_job_service_host_no_port():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
+def test_job_service_host_with_port():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:8000"
+
+
+def test_job_service_grpc_transport_channel():
+    channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.JobServiceGrpcTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_job_service_grpc_asyncio_transport_channel():
+    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.JobServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport],
+)
+def test_job_service_transport_channel_mtls_with_client_cert_source(transport_class):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel"
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport],
+)
+def test_job_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(
+            transport_class, "create_channel"
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_job_service_grpc_lro_client():
+    client = JobServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
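+    # (The transport caches the operations client on first access, which is
+    # what makes the identity assertion below hold.)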
+    assert transport.operations_client is transport.operations_client
+
+
+def test_job_service_grpc_lro_async_client():
+    client = JobServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_batch_prediction_job_path():
+    project = "squid"
+    location = "clam"
+    batch_prediction_job = "whelk"
+
+    expected = "projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}".format(
+        project=project, location=location, batch_prediction_job=batch_prediction_job,
+    )
+    actual = JobServiceClient.batch_prediction_job_path(
+        project, location, batch_prediction_job
+    )
+    assert expected == actual
+
+
+def test_parse_batch_prediction_job_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "batch_prediction_job": "nudibranch",
+    }
+    path = JobServiceClient.batch_prediction_job_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = JobServiceClient.parse_batch_prediction_job_path(path)
+    assert expected == actual
+
+
+def test_custom_job_path():
+    project = "cuttlefish"
+    location = "mussel"
+    custom_job = "winkle"
+
+    expected = "projects/{project}/locations/{location}/customJobs/{custom_job}".format(
+        project=project, location=location, custom_job=custom_job,
+    )
+    actual = JobServiceClient.custom_job_path(project, location, custom_job)
+    assert expected == actual
+
+
+def test_parse_custom_job_path():
+    expected = {
+        "project": "nautilus",
+        "location": "scallop",
+        "custom_job": "abalone",
+    }
+    path = JobServiceClient.custom_job_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = JobServiceClient.parse_custom_job_path(path)
+    assert expected == actual
+
+
+def test_data_labeling_job_path():
+    project = "squid"
+    location = "clam"
+    data_labeling_job = "whelk"
+
+    expected = "projects/{project}/locations/{location}/dataLabelingJobs/{data_labeling_job}".format(
+        project=project, location=location, data_labeling_job=data_labeling_job,
+    )
+    actual = JobServiceClient.data_labeling_job_path(
+        project, location, data_labeling_job
+    )
+    assert expected == actual
+
+
+def test_parse_data_labeling_job_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "data_labeling_job": "nudibranch",
+    }
+    path = JobServiceClient.data_labeling_job_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = JobServiceClient.parse_data_labeling_job_path(path)
+    assert expected == actual
+
+
+def test_dataset_path():
+    project = "cuttlefish"
+    location = "mussel"
+    dataset = "winkle"
+
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
+        project=project, location=location, dataset=dataset,
+    )
+    actual = JobServiceClient.dataset_path(project, location, dataset)
+    assert expected == actual
+
+
+def test_parse_dataset_path():
+    expected = {
+        "project": "nautilus",
+        "location": "scallop",
+        "dataset": "abalone",
+    }
+    path = JobServiceClient.dataset_path(**expected)
+
+    # Check that the path construction is reversible.
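+    # (Every *_path builder has a matching parse_* inverse; in application code
+    # the builders produce the fully qualified resource names requests expect.
+    # A hedged, hypothetical example: JobServiceClient.custom_job_path(
+    # "my-project", "us-central1", "123") returns
+    # "projects/my-project/locations/us-central1/customJobs/123".)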
+ actual = JobServiceClient.parse_dataset_path(path) + assert expected == actual + + +def test_hyperparameter_tuning_job_path(): + project = "squid" + location = "clam" + hyperparameter_tuning_job = "whelk" + + expected = "projects/{project}/locations/{location}/hyperparameterTuningJobs/{hyperparameter_tuning_job}".format( + project=project, + location=location, + hyperparameter_tuning_job=hyperparameter_tuning_job, + ) + actual = JobServiceClient.hyperparameter_tuning_job_path( + project, location, hyperparameter_tuning_job + ) + assert expected == actual + + +def test_parse_hyperparameter_tuning_job_path(): + expected = { + "project": "octopus", + "location": "oyster", + "hyperparameter_tuning_job": "nudibranch", + } + path = JobServiceClient.hyperparameter_tuning_job_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_hyperparameter_tuning_job_path(path) + assert expected == actual + + +def test_model_path(): + project = "cuttlefish" + location = "mussel" + model = "winkle" + + expected = "projects/{project}/locations/{location}/models/{model}".format( + project=project, location=location, model=model, + ) + actual = JobServiceClient.model_path(project, location, model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "model": "abalone", + } + path = JobServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_model_path(path) + assert expected == actual + + +def test_trial_path(): + project = "squid" + location = "clam" + study = "whelk" + trial = "octopus" + + expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) + actual = JobServiceClient.trial_path(project, location, study, trial) + assert expected == actual + + +def test_parse_trial_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + "study": "cuttlefish", + "trial": "mussel", + } + path = JobServiceClient.trial_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_trial_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "winkle" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = JobServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nautilus", + } + path = JobServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "scallop" + + expected = "folders/{folder}".format(folder=folder,) + actual = JobServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "abalone", + } + path = JobServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. 
+ actual = JobServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "squid" + + expected = "organizations/{organization}".format(organization=organization,) + actual = JobServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "clam", + } + path = JobServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "whelk" + + expected = "projects/{project}".format(project=project,) + actual = JobServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "octopus", + } + path = JobServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "oyster" + location = "nudibranch" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = JobServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "cuttlefish", + "location": "mussel", + } + path = JobServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.JobServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = JobServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.JobServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = JobServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py new file mode 100644 index 0000000000..9e1ca0513a --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -0,0 +1,1775 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1.services.migration_service import ( + MigrationServiceAsyncClient, +) +from google.cloud.aiplatform_v1.services.migration_service import MigrationServiceClient +from google.cloud.aiplatform_v1.services.migration_service import pagers +from google.cloud.aiplatform_v1.services.migration_service import transports +from google.cloud.aiplatform_v1.types import migratable_resource +from google.cloud.aiplatform_v1.types import migration_service +from google.longrunning import operations_pb2 +from google.oauth2 import service_account + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert MigrationServiceClient._get_default_mtls_endpoint(None) is None + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + MigrationServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +def test_migration_service_client_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = MigrationServiceClient.from_service_account_info(info) + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +@pytest.mark.parametrize( + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] +) +def test_migration_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = 
client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_migration_service_client_get_transport_class(): + transport = MigrationServiceClient.get_transport_class() + available_transports = [ + transports.MigrationServiceGrpcTransport, + ] + assert transport in available_transports + + transport = MigrationServiceClient.get_transport_class("grpc") + assert transport == transports.MigrationServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) +def test_migration_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(MigrationServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. 
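+    # (The recognized values for GOOGLE_API_USE_MTLS_ENDPOINT are "never",
+    # "auto", and "always"; anything else should raise MutualTLSChannelError,
+    # as checked below.)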
+ with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "true", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + MigrationServiceClient, + transports.MigrationServiceGrpcTransport, + "grpc", + "false", + ), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + MigrationServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceClient), +) +@mock.patch.object( + MigrationServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(MigrationServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_migration_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
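+    # (An "ADC client cert" here is one discovered through
+    # google.auth.transport.mtls.default_client_cert_source, as mocked below,
+    # rather than one passed explicitly via client options.)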
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (MigrationServiceClient, transports.MigrationServiceGrpcTransport, "grpc"), + ( + MigrationServiceAsyncClient, + transports.MigrationServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_migration_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
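+    # (In real use this corresponds to something like the hedged sketch below,
+    # where "credentials.json" stands in for a service account key file:
+    #
+    #     options = client_options.ClientOptions(credentials_file="credentials.json")
+    #     client = MigrationServiceClient(client_options=options)
+    # )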
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_migration_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = MigrationServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_search_migratable_resources( + transport: str = "grpc", + request_type=migration_service.SearchMigratableResourcesRequest, +): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = migration_service.SearchMigratableResourcesResponse( + next_page_token="next_page_token_value", + ) + + response = client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == migration_service.SearchMigratableResourcesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.SearchMigratableResourcesPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_search_migratable_resources_from_dict(): + test_search_migratable_resources(request_type=dict) + + +@pytest.mark.asyncio +async def test_search_migratable_resources_async( + transport: str = "grpc_asyncio", + request_type=migration_service.SearchMigratableResourcesRequest, +): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == migration_service.SearchMigratableResourcesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.SearchMigratableResourcesAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_search_migratable_resources_async_from_dict(): + await test_search_migratable_resources_async(request_type=dict) + + +def test_search_migratable_resources_field_headers(): + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = migration_service.SearchMigratableResourcesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), "__call__" + ) as call: + call.return_value = migration_service.SearchMigratableResourcesResponse() + + client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_search_migratable_resources_field_headers_async(): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = migration_service.SearchMigratableResourcesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + migration_service.SearchMigratableResourcesResponse() + ) + + await client.search_migratable_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_search_migratable_resources_flattened(): + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.search_migratable_resources), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = migration_service.SearchMigratableResourcesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.search_migratable_resources(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
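+        # (Flattened keyword arguments such as parent="parent_value" are a
+        # convenience the client folds into the request object; supplying both
+        # a request and flattened fields is rejected, as the *_flattened_error
+        # tests below verify.)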
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_search_migratable_resources_flattened_error():
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.search_migratable_resources(
+            migration_service.SearchMigratableResourcesRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_search_migratable_resources_flattened_async():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            migration_service.SearchMigratableResourcesResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.search_migratable_resources(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_search_migratable_resources_flattened_error_async():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.search_migratable_resources(
+            migration_service.SearchMigratableResourcesRequest(), parent="parent_value",
+        )
+
+
+def test_search_migratable_resources_pager():
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
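+        # (In application code the pager hides this paging entirely; a minimal
+        # hedged sketch, with a hypothetical project and location:
+        #
+        #     parent = "projects/my-project/locations/us-central1"
+        #     for resource in client.search_migratable_resources(
+        #         request={"parent": parent}
+        #     ):
+        #         print(resource)
+        # )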
+        call.side_effect = (
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token="abc",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[], next_page_token="def",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[migratable_resource.MigratableResource(),],
+                next_page_token="ghi",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.search_migratable_resources(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(
+            isinstance(i, migratable_resource.MigratableResource) for i in results
+        )
+
+
+def test_search_migratable_resources_pages():
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.search_migratable_resources), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token="abc",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[], next_page_token="def",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[migratable_resource.MigratableResource(),],
+                next_page_token="ghi",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.search_migratable_resources(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_search_migratable_resources_async_pager():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.search_migratable_resources),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token="abc",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[], next_page_token="def",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[migratable_resource.MigratableResource(),],
+                next_page_token="ghi",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.search_migratable_resources(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(
+            isinstance(i, migratable_resource.MigratableResource) for i in responses
+        )
+
+
+@pytest.mark.asyncio
+async def test_search_migratable_resources_async_pages():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.search_migratable_resources),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+                next_page_token="abc",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[], next_page_token="def",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[migratable_resource.MigratableResource(),],
+                next_page_token="ghi",
+            ),
+            migration_service.SearchMigratableResourcesResponse(
+                migratable_resources=[
+                    migratable_resource.MigratableResource(),
+                    migratable_resource.MigratableResource(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.search_migratable_resources(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_batch_migrate_resources(
+    transport: str = "grpc", request_type=migration_service.BatchMigrateResourcesRequest
+):
+    client = MigrationServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.batch_migrate_resources), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.batch_migrate_resources(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == migration_service.BatchMigrateResourcesRequest()
+
+    # Establish that the response is the type that we expect.
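+    # (batch_migrate_resources is a long-running operation, so the client hands
+    # back a future; a real caller would typically block on response.result()
+    # to obtain the BatchMigrateResourcesResponse.)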
+ assert isinstance(response, future.Future) + + +def test_batch_migrate_resources_from_dict(): + test_batch_migrate_resources(request_type=dict) + + +@pytest.mark.asyncio +async def test_batch_migrate_resources_async( + transport: str = "grpc_asyncio", + request_type=migration_service.BatchMigrateResourcesRequest, +): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == migration_service.BatchMigrateResourcesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_batch_migrate_resources_async_from_dict(): + await test_batch_migrate_resources_async(request_type=dict) + + +def test_batch_migrate_resources_field_headers(): + client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = migration_service.BatchMigrateResourcesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_batch_migrate_resources_field_headers_async(): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = migration_service.BatchMigrateResourcesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_migrate_resources), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.batch_migrate_resources(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
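+    # (The x-goog-request-params metadata entry carries the request's routing
+    # fields, here the parent, so the backend can route the call correctly.)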
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_batch_migrate_resources_flattened():
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.batch_migrate_resources), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.batch_migrate_resources(
+            parent="parent_value",
+            migrate_resource_requests=[
+                migration_service.MigrateResourceRequest(
+                    migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(
+                        endpoint="endpoint_value"
+                    )
+                )
+            ],
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+        assert args[0].migrate_resource_requests == [
+            migration_service.MigrateResourceRequest(
+                migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(
+                    endpoint="endpoint_value"
+                )
+            )
+        ]
+
+
+def test_batch_migrate_resources_flattened_error():
+    client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.batch_migrate_resources(
+            migration_service.BatchMigrateResourcesRequest(),
+            parent="parent_value",
+            migrate_resource_requests=[
+                migration_service.MigrateResourceRequest(
+                    migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(
+                        endpoint="endpoint_value"
+                    )
+                )
+            ],
+        )
+
+
+@pytest.mark.asyncio
+async def test_batch_migrate_resources_flattened_async():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(),
+    )
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.batch_migrate_resources), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.batch_migrate_resources(
+            parent="parent_value",
+            migrate_resource_requests=[
+                migration_service.MigrateResourceRequest(
+                    migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig(
+                        endpoint="endpoint_value"
+                    )
+                )
+            ],
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].migrate_resource_requests == [ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ] + + +@pytest.mark.asyncio +async def test_batch_migrate_resources_flattened_error_async(): + client = MigrationServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.batch_migrate_resources( + migration_service.BatchMigrateResourcesRequest(), + parent="parent_value", + migrate_resource_requests=[ + migration_service.MigrateResourceRequest( + migrate_ml_engine_model_version_config=migration_service.MigrateResourceRequest.MigrateMlEngineModelVersionConfig( + endpoint="endpoint_value" + ) + ) + ], + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = MigrationServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = MigrationServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = MigrationServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.MigrationServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.MigrationServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. 
+ client = MigrationServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.MigrationServiceGrpcTransport,) + + +def test_migration_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.MigrationServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_migration_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.MigrationServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "search_migratable_resources", + "batch_migrate_resources", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_migration_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.MigrationServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_migration_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.migration_service.transports.MigrationServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.MigrationServiceTransport() + adc.assert_called_once() + + +def test_migration_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + MigrationServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_migration_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
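+    # (ADC = Application Default Credentials, resolved by google.auth.default();
+    # in a real environment that typically means GOOGLE_APPLICATION_CREDENTIALS
+    # or the ambient service account, which the mock below stands in for.)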
+    with mock.patch.object(auth, "default") as adc:
+        adc.return_value = (credentials.AnonymousCredentials(), None)
+        transports.MigrationServiceGrpcTransport(
+            host="squid.clam.whelk", quota_project_id="octopus"
+        )
+        adc.assert_called_once_with(
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            quota_project_id="octopus",
+        )
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.MigrationServiceGrpcTransport,
+        transports.MigrationServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_class):
+    cred = credentials.AnonymousCredentials()
+
+    # Check ssl_channel_credentials is used if provided.
+    with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+        mock_ssl_channel_creds = mock.Mock()
+        transport_class(
+            host="squid.clam.whelk",
+            credentials=cred,
+            ssl_channel_credentials=mock_ssl_channel_creds,
+        )
+        mock_create_channel.assert_called_once_with(
+            "squid.clam.whelk:443",
+            credentials=cred,
+            credentials_file=None,
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            ssl_credentials=mock_ssl_channel_creds,
+            quota_project_id=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+    # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+    # is used.
+    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+            transport_class(
+                credentials=cred,
+                client_cert_source_for_mtls=client_cert_source_callback,
+            )
+            expected_cert, expected_key = client_cert_source_callback()
+            mock_ssl_cred.assert_called_once_with(
+                certificate_chain=expected_cert, private_key=expected_key
+            )
+
+
+def test_migration_service_host_no_port():
+    client = MigrationServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
+def test_migration_service_host_with_port():
+    client = MigrationServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:8000"
+
+
+def test_migration_service_grpc_transport_channel():
+    channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.MigrationServiceGrpcTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_migration_service_grpc_asyncio_transport_channel():
+    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.MigrationServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.MigrationServiceGrpcTransport,
+        transports.MigrationServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_migration_service_transport_channel_mtls_with_client_cert_source(
+    transport_class,
+):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel"
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.MigrationServiceGrpcTransport,
+        transports.MigrationServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_migration_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(
+            transport_class, "create_channel"
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_migration_service_grpc_lro_client():
+    client = MigrationServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
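+    # (i.e. the transport caches the operations client after the first access).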
+    assert transport.operations_client is transport.operations_client
+
+
+def test_migration_service_grpc_lro_async_client():
+    client = MigrationServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_annotated_dataset_path():
+    project = "squid"
+    dataset = "clam"
+    annotated_dataset = "whelk"
+
+    expected = "projects/{project}/datasets/{dataset}/annotatedDatasets/{annotated_dataset}".format(
+        project=project, dataset=dataset, annotated_dataset=annotated_dataset,
+    )
+    actual = MigrationServiceClient.annotated_dataset_path(
+        project, dataset, annotated_dataset
+    )
+    assert expected == actual
+
+
+def test_parse_annotated_dataset_path():
+    expected = {
+        "project": "octopus",
+        "dataset": "oyster",
+        "annotated_dataset": "nudibranch",
+    }
+    path = MigrationServiceClient.annotated_dataset_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_annotated_dataset_path(path)
+    assert expected == actual
+
+
+def test_dataset_path():
+    project = "cuttlefish"
+    location = "mussel"
+    dataset = "winkle"
+
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
+        project=project, location=location, dataset=dataset,
+    )
+    actual = MigrationServiceClient.dataset_path(project, location, dataset)
+    assert expected == actual
+
+
+def test_parse_dataset_path():
+    expected = {
+        "project": "nautilus",
+        "location": "scallop",
+        "dataset": "abalone",
+    }
+    path = MigrationServiceClient.dataset_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_dataset_path(path)
+    assert expected == actual
+
+
+# NOTE: this test and the one after it redefine test_dataset_path /
+# test_parse_dataset_path for the other dataset path formats handled during
+# migration. Later definitions shadow earlier ones, so pytest collects only
+# the final pair.
+def test_dataset_path():
+    project = "squid"
+    location = "clam"
+    dataset = "whelk"
+
+    expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
+        project=project, location=location, dataset=dataset,
+    )
+    actual = MigrationServiceClient.dataset_path(project, location, dataset)
+    assert expected == actual
+
+
+def test_parse_dataset_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "dataset": "nudibranch",
+    }
+    path = MigrationServiceClient.dataset_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_dataset_path(path)
+    assert expected == actual
+
+
+def test_dataset_path():
+    project = "cuttlefish"
+    dataset = "mussel"
+
+    expected = "projects/{project}/datasets/{dataset}".format(
+        project=project, dataset=dataset,
+    )
+    actual = MigrationServiceClient.dataset_path(project, dataset)
+    assert expected == actual
+
+
+def test_parse_dataset_path():
+    expected = {
+        "project": "winkle",
+        "dataset": "nautilus",
+    }
+    path = MigrationServiceClient.dataset_path(**expected)
+
+    # Check that the path construction is reversible.
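+    # parse_dataset_path should recover exactly the components given to dataset_path.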
+    actual = MigrationServiceClient.parse_dataset_path(path)
+    assert expected == actual
+
+
+def test_model_path():
+    project = "scallop"
+    location = "abalone"
+    model = "squid"
+
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
+    actual = MigrationServiceClient.model_path(project, location, model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "project": "clam",
+        "location": "whelk",
+        "model": "octopus",
+    }
+    path = MigrationServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_model_path(path)
+    assert expected == actual
+
+
+# NOTE: as with test_dataset_path above, this second test_model_path /
+# test_parse_model_path pair shadows the first; only the definitions below
+# are collected by pytest.
+def test_model_path():
+    project = "oyster"
+    location = "nudibranch"
+    model = "cuttlefish"
+
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
+    actual = MigrationServiceClient.model_path(project, location, model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "project": "mussel",
+        "location": "winkle",
+        "model": "nautilus",
+    }
+    path = MigrationServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_model_path(path)
+    assert expected == actual
+
+
+def test_version_path():
+    project = "scallop"
+    model = "abalone"
+    version = "squid"
+
+    expected = "projects/{project}/models/{model}/versions/{version}".format(
+        project=project, model=model, version=version,
+    )
+    actual = MigrationServiceClient.version_path(project, model, version)
+    assert expected == actual
+
+
+def test_parse_version_path():
+    expected = {
+        "project": "clam",
+        "model": "whelk",
+        "version": "octopus",
+    }
+    path = MigrationServiceClient.version_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_version_path(path)
+    assert expected == actual
+
+
+def test_common_billing_account_path():
+    billing_account = "oyster"
+
+    expected = "billingAccounts/{billing_account}".format(
+        billing_account=billing_account,
+    )
+    actual = MigrationServiceClient.common_billing_account_path(billing_account)
+    assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+    expected = {
+        "billing_account": "nudibranch",
+    }
+    path = MigrationServiceClient.common_billing_account_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = MigrationServiceClient.parse_common_billing_account_path(path)
+    assert expected == actual
+
+
+def test_common_folder_path():
+    folder = "cuttlefish"
+
+    expected = "folders/{folder}".format(folder=folder,)
+    actual = MigrationServiceClient.common_folder_path(folder)
+    assert expected == actual
+
+
+def test_parse_common_folder_path():
+    expected = {
+        "folder": "mussel",
+    }
+    path = MigrationServiceClient.common_folder_path(**expected)
+
+    # Check that the path construction is reversible.
+ actual = MigrationServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "winkle" + + expected = "organizations/{organization}".format(organization=organization,) + actual = MigrationServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nautilus", + } + path = MigrationServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = MigrationServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "scallop" + + expected = "projects/{project}".format(project=project,) + actual = MigrationServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "abalone", + } + path = MigrationServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = MigrationServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "squid" + location = "clam" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = MigrationServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "whelk", + "location": "octopus", + } + path = MigrationServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = MigrationServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.MigrationServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = MigrationServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.MigrationServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = MigrationServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_model_service.py b/tests/unit/gapic/aiplatform_v1/test_model_service.py new file mode 100644 index 0000000000..f03d9e5d31 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1/test_model_service.py @@ -0,0 +1,3729 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1.services.model_service import ModelServiceAsyncClient +from google.cloud.aiplatform_v1.services.model_service import ModelServiceClient +from google.cloud.aiplatform_v1.services.model_service import pagers +from google.cloud.aiplatform_v1.services.model_service import transports +from google.cloud.aiplatform_v1.types import deployed_model_ref +from google.cloud.aiplatform_v1.types import encryption_spec +from google.cloud.aiplatform_v1.types import env_var +from google.cloud.aiplatform_v1.types import io +from google.cloud.aiplatform_v1.types import model +from google.cloud.aiplatform_v1.types import model as gca_model +from google.cloud.aiplatform_v1.types import model_evaluation +from google.cloud.aiplatform_v1.types import model_evaluation_slice +from google.cloud.aiplatform_v1.types import model_service +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore +from google.protobuf import struct_pb2 as struct # type: ignore +from google.protobuf import timestamp_pb2 as timestamp # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
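+# For example, a "localhost" default endpoint is swapped for "foo.googleapis.com"
+# so that a distinct "foo.mtls.googleapis.com" can be derived from it.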
+def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert ModelServiceClient._get_default_mtls_endpoint(None) is None + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +def test_model_service_client_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = ModelServiceClient.from_service_account_info(info) + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient,]) +def test_model_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_model_service_client_get_transport_class(): + transport = ModelServiceClient.get_transport_class() + available_transports = [ + transports.ModelServiceGrpcTransport, + ] + assert transport in available_transports + + transport = ModelServiceClient.get_transport_class("grpc") + assert transport == transports.ModelServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) +def test_model_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. 
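+    # (e.g. transport="grpc", which is resolved through get_transport_class).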
+ with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
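+    # Only "true" and "false" are accepted; anything else should surface as a ValueError.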
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_model_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
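+    # Here the certificate comes from the ADC default client cert source rather
+    # than from client_options.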
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_model_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_model_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
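+    # The file path should be forwarded to the transport as-is; no credentials
+    # object is constructed at this layer.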
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_model_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_upload_model( + transport: str = "grpc", request_type=model_service.UploadModelRequest +): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.UploadModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_upload_model_from_dict(): + test_upload_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_upload_model_async( + transport: str = "grpc_asyncio", request_type=model_service.UploadModelRequest +): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.UploadModelRequest() + + # Establish that the response is the type that we expect. 
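+    # (an operation future, since UploadModel is a long-running operation).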
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_upload_model_async_from_dict(): + await test_upload_model_async(request_type=dict) + + +def test_upload_model_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.UploadModelRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_upload_model_field_headers_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.UploadModelRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.upload_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_upload_model_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.upload_model( + parent="parent_value", model=gca_model.Model(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].model == gca_model.Model(name="name_value") + + +def test_upload_model_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.upload_model( + model_service.UploadModelRequest(), + parent="parent_value", + model=gca_model.Model(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_upload_model_flattened_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.upload_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.upload_model( + parent="parent_value", model=gca_model.Model(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].model == gca_model.Model(name="name_value") + + +@pytest.mark.asyncio +async def test_upload_model_flattened_error_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.upload_model( + model_service.UploadModelRequest(), + parent="parent_value", + model=gca_model.Model(name="name_value"), + ) + + +def test_get_model(transport: str = "grpc", request_type=model_service.GetModelRequest): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=["supported_input_storage_formats_value"], + supported_output_storage_formats=["supported_output_storage_formats_value"], + etag="etag_value", + ) + + response = client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelRequest() + + # Establish that the response is the type that we expect. 
+ + assert isinstance(response, model.Model) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.training_pipeline == "training_pipeline_value" + + assert response.artifact_uri == "artifact_uri_value" + + assert response.supported_deployment_resources_types == [ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] + + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] + + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] + + assert response.etag == "etag_value" + + +def test_get_model_from_dict(): + test_get_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_model_async( + transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest +): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=[ + "supported_input_storage_formats_value" + ], + supported_output_storage_formats=[ + "supported_output_storage_formats_value" + ], + etag="etag_value", + ) + ) + + response = await client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.GetModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, model.Model) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.training_pipeline == "training_pipeline_value" + + assert response.artifact_uri == "artifact_uri_value" + + assert response.supported_deployment_resources_types == [ + model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] + + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] + + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_get_model_async_from_dict(): + await test_get_model_async(request_type=dict) + + +def test_get_model_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
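+    # GetModel routes on the request's "name" field.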
+ request = model_service.GetModelRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + call.return_value = model.Model() + + client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_model_field_headers_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) + + await client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_model_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model.Model() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_model(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_model_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_model( + model_service.GetModelRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_model_flattened_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model.Model() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_model(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_model_flattened_error_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_model( + model_service.GetModelRequest(), name="name_value", + ) + + +def test_list_models( + transport: str = "grpc", request_type=model_service.ListModelsRequest +): + client = ModelServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ListModelsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListModelsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_models_from_dict(): + test_list_models(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_models_async( + transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest +): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse(next_page_token="next_page_token_value",) + ) + + response = await client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.ListModelsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_models_async_from_dict(): + await test_list_models_async(request_type=dict) + + +def test_list_models_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
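+    # (ListModels routes on "parent" rather than "name".)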
+ with mock.patch.object(type(client.transport.list_models), "__call__") as call: + call.return_value = model_service.ListModelsResponse() + + client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_models_field_headers_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.ListModelsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse() + ) + + await client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_models_flattened(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_models(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_models_flattened_error(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_models( + model_service.ListModelsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_models_flattened_async(): + client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_models(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+    assert len(call.mock_calls)
+    _, args, _ = call.mock_calls[0]
+
+    assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_models_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_models(
+            model_service.ListModelsRequest(), parent="parent_value",
+        )
+
+
+def test_list_models_pager():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
+            ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
+            model_service.ListModelsResponse(
+                models=[model.Model(),], next_page_token="ghi",
+            ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_models(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, model.Model) for i in results)
+
+
+def test_list_models_pages():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.list_models), "__call__") as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
+            ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
+            model_service.ListModelsResponse(
+                models=[model.Model(),], next_page_token="ghi",
+            ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
+            RuntimeError,
+        )
+        pages = list(client.list_models(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_models_async_pager():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
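+        # Four pages holding 3, 0, 1, and 2 models; the trailing RuntimeError
+        # would fail the test if a fifth page were ever requested.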
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
+            ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
+            model_service.ListModelsResponse(
+                models=[model.Model(),], next_page_token="ghi",
+            ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
+            RuntimeError,
+        )
+        async_pager = await client.list_models(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, model.Model) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_models_async_pages():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelsResponse(
+                models=[model.Model(), model.Model(), model.Model(),],
+                next_page_token="abc",
+            ),
+            model_service.ListModelsResponse(models=[], next_page_token="def",),
+            model_service.ListModelsResponse(
+                models=[model.Model(),], next_page_token="ghi",
+            ),
+            model_service.ListModelsResponse(models=[model.Model(), model.Model(),],),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_models(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_update_model(
+    transport: str = "grpc", request_type=model_service.UpdateModelRequest
+):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gca_model.Model(
+            name="name_value",
+            display_name="display_name_value",
+            description="description_value",
+            metadata_schema_uri="metadata_schema_uri_value",
+            training_pipeline="training_pipeline_value",
+            artifact_uri="artifact_uri_value",
+            supported_deployment_resources_types=[
+                gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES
+            ],
+            supported_input_storage_formats=["supported_input_storage_formats_value"],
+            supported_output_storage_formats=["supported_output_storage_formats_value"],
+            etag="etag_value",
+        )
+
+        response = client.update_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.UpdateModelRequest()
+
+    # Establish that the response is the type that we expect.
+ + assert isinstance(response, gca_model.Model) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.training_pipeline == "training_pipeline_value" + + assert response.artifact_uri == "artifact_uri_value" + + assert response.supported_deployment_resources_types == [ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] + + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] + + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] + + assert response.etag == "etag_value" + + +def test_update_model_from_dict(): + test_update_model(request_type=dict) + + +@pytest.mark.asyncio +async def test_update_model_async( + transport: str = "grpc_asyncio", request_type=model_service.UpdateModelRequest +): + client = ModelServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_model.Model( + name="name_value", + display_name="display_name_value", + description="description_value", + metadata_schema_uri="metadata_schema_uri_value", + training_pipeline="training_pipeline_value", + artifact_uri="artifact_uri_value", + supported_deployment_resources_types=[ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ], + supported_input_storage_formats=[ + "supported_input_storage_formats_value" + ], + supported_output_storage_formats=[ + "supported_output_storage_formats_value" + ], + etag="etag_value", + ) + ) + + response = await client.update_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == model_service.UpdateModelRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_model.Model) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.description == "description_value" + + assert response.metadata_schema_uri == "metadata_schema_uri_value" + + assert response.training_pipeline == "training_pipeline_value" + + assert response.artifact_uri == "artifact_uri_value" + + assert response.supported_deployment_resources_types == [ + gca_model.Model.DeploymentResourcesType.DEDICATED_RESOURCES + ] + + assert response.supported_input_storage_formats == [ + "supported_input_storage_formats_value" + ] + + assert response.supported_output_storage_formats == [ + "supported_output_storage_formats_value" + ] + + assert response.etag == "etag_value" + + +@pytest.mark.asyncio +async def test_update_model_async_from_dict(): + await test_update_model_async(request_type=dict) + + +def test_update_model_field_headers(): + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
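+    # UpdateModel routes on the nested "model.name" field, hence the header key
+    # asserted below.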
+
+
+def test_update_model_field_headers():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.UpdateModelRequest()
+    request.model.name = "model.name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
+        call.return_value = gca_model.Model()
+
+        client.update_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_update_model_field_headers_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.UpdateModelRequest()
+    request.model.name = "model.name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model())
+
+        await client.update_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "model.name=model.name/value",) in kw["metadata"]
+
+
+def test_update_model_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = gca_model.Model()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.update_model(
+            model=gca_model.Model(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].model == gca_model.Model(name="name_value")
+
+        assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"])
+
+
+def test_update_model_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.update_model(
+            model_service.UpdateModelRequest(),
+            model=gca_model.Model(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
+        )
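
The `flattened`/`flattened_error` pair above encodes a general GAPIC contract: each method accepts either a fully built request object or flattened keyword arguments, never both. Schematically (`my_model` and `my_mask` are placeholders):

    # Either form is accepted...
    client.update_model(request=model_service.UpdateModelRequest(model=my_model))
    client.update_model(model=my_model, update_mask=my_mask)

    # ...but combining a request object with flattened fields raises ValueError.
    client.update_model(model_service.UpdateModelRequest(), model=my_model)
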
+
+
+@pytest.mark.asyncio
+async def test_update_model_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.update_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(gca_model.Model())
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.update_model(
+            model=gca_model.Model(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].model == gca_model.Model(name="name_value")
+
+        assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"])
+
+
+@pytest.mark.asyncio
+async def test_update_model_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.update_model(
+            model_service.UpdateModelRequest(),
+            model=gca_model.Model(name="name_value"),
+            update_mask=field_mask.FieldMask(paths=["paths_value"]),
+        )
+
+
+def test_delete_model(
+    transport: str = "grpc", request_type=model_service.DeleteModelRequest
+):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.delete_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.DeleteModelRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+def test_delete_model_from_dict():
+    test_delete_model(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_delete_model_async(
+    transport: str = "grpc_asyncio", request_type=model_service.DeleteModelRequest
+):
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+
+        response = await client.delete_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.DeleteModelRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+@pytest.mark.asyncio
+async def test_delete_model_async_from_dict():
+    await test_delete_model_async(request_type=dict)
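
`delete_model` is a long-running operation: the stub hands back an `operations_pb2.Operation`, and the client wraps it in the `future.Future` interface these tests assert on. In application code the usual pattern is to block on the returned future (sketch; the timeout is illustrative):

    operation = client.delete_model(name=model_name)
    operation.result(timeout=300)  # block until the server-side deletion finishes
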
+
+
+def test_delete_model_field_headers():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.DeleteModelRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        client.delete_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_delete_model_field_headers_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.DeleteModelRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/op")
+        )
+
+        await client.delete_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_delete_model_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.delete_model(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_delete_model_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.delete_model(
+            model_service.DeleteModelRequest(), name="name_value",
+        )
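
The field-header tests above check the routing-metadata contract: any request field that appears in the HTTP URI template is mirrored into the `x-goog-request-params` gRPC metadata entry so the backend can route the call. Inside the generated client this is built with the same expression the pager tests further down use:

    metadata = tuple(metadata) + (
        gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)),
    )
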
+
+
+@pytest.mark.asyncio
+async def test_delete_model_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.delete_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_model(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_delete_model_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.delete_model(
+            model_service.DeleteModelRequest(), name="name_value",
+        )
+
+
+def test_export_model(
+    transport: str = "grpc", request_type=model_service.ExportModelRequest
+):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/spam")
+
+        response = client.export_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.ExportModelRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+def test_export_model_from_dict():
+    test_export_model(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_export_model_async(
+    transport: str = "grpc_asyncio", request_type=model_service.ExportModelRequest
+):
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+
+        response = await client.export_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.ExportModelRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, future.Future)
+
+
+@pytest.mark.asyncio
+async def test_export_model_async_from_dict():
+    await test_export_model_async(request_type=dict)
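
`export_model` is also a long-running operation, but with a nested `OutputConfig` message selecting the export format, which the flattened tests below exercise. A sketch of the call (the format id here is a placeholder; valid ids depend on the model):

    operation = client.export_model(
        name=model_name,
        output_config=model_service.ExportModelRequest.OutputConfig(
            export_format_id="tf-saved-model",
        ),
    )
    operation.result()
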
+
+
+def test_export_model_field_headers():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.ExportModelRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        client.export_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_export_model_field_headers_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.ExportModelRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/op")
+        )
+
+        await client.export_model(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_export_model_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = operations_pb2.Operation(name="operations/op")
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.export_model(
+            name="name_value",
+            output_config=model_service.ExportModelRequest.OutputConfig(
+                export_format_id="export_format_id_value"
+            ),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+        assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(
+            export_format_id="export_format_id_value"
+        )
+
+
+def test_export_model_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.export_model(
+            model_service.ExportModelRequest(),
+            name="name_value",
+            output_config=model_service.ExportModelRequest.OutputConfig(
+                export_format_id="export_format_id_value"
+            ),
+        )
+
+
+@pytest.mark.asyncio
+async def test_export_model_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(type(client.transport.export_model), "__call__") as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.export_model(
+            name="name_value",
+            output_config=model_service.ExportModelRequest.OutputConfig(
+                export_format_id="export_format_id_value"
+            ),
+        )
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+        assert args[0].output_config == model_service.ExportModelRequest.OutputConfig(
+            export_format_id="export_format_id_value"
+        )
+
+
+@pytest.mark.asyncio
+async def test_export_model_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.export_model(
+            model_service.ExportModelRequest(),
+            name="name_value",
+            output_config=model_service.ExportModelRequest.OutputConfig(
+                export_format_id="export_format_id_value"
+            ),
+        )
+
+
+def test_get_model_evaluation(
+    transport: str = "grpc", request_type=model_service.GetModelEvaluationRequest
+):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_evaluation.ModelEvaluation(
+            name="name_value",
+            metrics_schema_uri="metrics_schema_uri_value",
+            slice_dimensions=["slice_dimensions_value"],
+        )
+
+        response = client.get_model_evaluation(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.GetModelEvaluationRequest()
+
+    # Establish that the response is the type that we expect.
+
+    assert isinstance(response, model_evaluation.ModelEvaluation)
+
+    assert response.name == "name_value"
+
+    assert response.metrics_schema_uri == "metrics_schema_uri_value"
+
+    assert response.slice_dimensions == ["slice_dimensions_value"]
+
+
+def test_get_model_evaluation_from_dict():
+    test_get_model_evaluation(request_type=dict)
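
A note on the async mocking pattern used throughout this file: an async transport's `__call__` must return an awaitable, so a bare message as `return_value` would break `await`. `grpc_helpers_async.FakeUnaryUnaryCall` from `google.api_core` wraps the message in an awaitable call object, letting the sync and async tests share the same assertion structure:

    with mock.patch.object(
        type(client.transport.get_model_evaluation), "__call__"
    ) as call:
        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
            model_evaluation.ModelEvaluation(name="name_value")
        )
        response = await client.get_model_evaluation(request)  # awaits the fake call
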
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_async(
+    transport: str = "grpc_asyncio",
+    request_type=model_service.GetModelEvaluationRequest,
+):
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_evaluation.ModelEvaluation(
+                name="name_value",
+                metrics_schema_uri="metrics_schema_uri_value",
+                slice_dimensions=["slice_dimensions_value"],
+            )
+        )
+
+        response = await client.get_model_evaluation(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.GetModelEvaluationRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, model_evaluation.ModelEvaluation)
+
+    assert response.name == "name_value"
+
+    assert response.metrics_schema_uri == "metrics_schema_uri_value"
+
+    assert response.slice_dimensions == ["slice_dimensions_value"]
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_async_from_dict():
+    await test_get_model_evaluation_async(request_type=dict)
+
+
+def test_get_model_evaluation_field_headers():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.GetModelEvaluationRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
+        call.return_value = model_evaluation.ModelEvaluation()
+
+        client.get_model_evaluation(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_field_headers_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.GetModelEvaluationRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_evaluation.ModelEvaluation()
+        )
+
+        await client.get_model_evaluation(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_get_model_evaluation_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_evaluation.ModelEvaluation()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.get_model_evaluation(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_get_model_evaluation_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_model_evaluation(
+            model_service.GetModelEvaluationRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_evaluation.ModelEvaluation()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_model_evaluation(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_model_evaluation(
+            model_service.GetModelEvaluationRequest(), name="name_value",
+        )
+
+
+def test_list_model_evaluations(
+    transport: str = "grpc", request_type=model_service.ListModelEvaluationsRequest
+):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelEvaluationsResponse(
+            next_page_token="next_page_token_value",
+        )
+
+        response = client.list_model_evaluations(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.ListModelEvaluationsRequest()
+
+    # Establish that the response is the type that we expect.
+
+    assert isinstance(response, pagers.ListModelEvaluationsPager)
+
+    assert response.next_page_token == "next_page_token_value"
+
+
+def test_list_model_evaluations_from_dict():
+    test_list_model_evaluations(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_async(
+    transport: str = "grpc_asyncio",
+    request_type=model_service.ListModelEvaluationsRequest,
+):
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelEvaluationsResponse(
+                next_page_token="next_page_token_value",
+            )
+        )
+
+        response = await client.list_model_evaluations(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.ListModelEvaluationsRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListModelEvaluationsAsyncPager)
+
+    assert response.next_page_token == "next_page_token_value"
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_async_from_dict():
+    await test_list_model_evaluations_async(request_type=dict)
+
+
+def test_list_model_evaluations_field_headers():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.ListModelEvaluationsRequest()
+    request.parent = "parent/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        call.return_value = model_service.ListModelEvaluationsResponse()
+
+        client.list_model_evaluations(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_field_headers_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.ListModelEvaluationsRequest()
+    request.parent = "parent/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelEvaluationsResponse()
+        )
+
+        await client.list_model_evaluations(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_model_evaluations_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelEvaluationsResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_model_evaluations(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_model_evaluations_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_model_evaluations(
+            model_service.ListModelEvaluationsRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelEvaluationsResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_model_evaluations(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_model_evaluations(
+            model_service.ListModelEvaluationsRequest(), parent="parent_value",
+        )
+
+
+def test_list_model_evaluations_pager():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[model_evaluation.ModelEvaluation(),],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_model_evaluations(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in results)
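
The trailing `RuntimeError` in every `side_effect` tuple is a guard rather than a page: `mock` returns the items of `side_effect` one call at a time and raises when it reaches an exception, so a pager that requested one page too many would fail loudly instead of looping. In miniature:

    stub = mock.Mock(side_effect=("page one", "page two", RuntimeError))
    stub()  # -> "page one"
    stub()  # -> "page two"
    stub()  # raises RuntimeError
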
+
+
+def test_list_model_evaluations_pages():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[model_evaluation.ModelEvaluation(),],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_model_evaluations(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_async_pager():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[model_evaluation.ModelEvaluation(),],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_model_evaluations(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, model_evaluation.ModelEvaluation) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluations_async_pages():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluations),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[model_evaluation.ModelEvaluation(),],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationsResponse(
+                model_evaluations=[
+                    model_evaluation.ModelEvaluation(),
+                    model_evaluation.ModelEvaluation(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (await client.list_model_evaluations(request={})).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+def test_get_model_evaluation_slice(
+    transport: str = "grpc", request_type=model_service.GetModelEvaluationSliceRequest
+):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation_slice), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_evaluation_slice.ModelEvaluationSlice(
+            name="name_value", metrics_schema_uri="metrics_schema_uri_value",
+        )
+
+        response = client.get_model_evaluation_slice(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.GetModelEvaluationSliceRequest()
+
+    # Establish that the response is the type that we expect.
+
+    assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice)
+
+    assert response.name == "name_value"
+
+    assert response.metrics_schema_uri == "metrics_schema_uri_value"
+
+
+def test_get_model_evaluation_slice_from_dict():
+    test_get_model_evaluation_slice(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_slice_async(
+    transport: str = "grpc_asyncio",
+    request_type=model_service.GetModelEvaluationSliceRequest,
+):
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation_slice), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_evaluation_slice.ModelEvaluationSlice(
+                name="name_value", metrics_schema_uri="metrics_schema_uri_value",
+            )
+        )
+
+        response = await client.get_model_evaluation_slice(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.GetModelEvaluationSliceRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, model_evaluation_slice.ModelEvaluationSlice)
+
+    assert response.name == "name_value"
+
+    assert response.metrics_schema_uri == "metrics_schema_uri_value"
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_slice_async_from_dict():
+    await test_get_model_evaluation_slice_async(request_type=dict)
+
+
+def test_get_model_evaluation_slice_field_headers():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.GetModelEvaluationSliceRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation_slice), "__call__"
+    ) as call:
+        call.return_value = model_evaluation_slice.ModelEvaluationSlice()
+
+        client.get_model_evaluation_slice(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_slice_field_headers_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.GetModelEvaluationSliceRequest()
+    request.name = "name/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation_slice), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_evaluation_slice.ModelEvaluationSlice()
+        )
+
+        await client.get_model_evaluation_slice(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "name=name/value",) in kw["metadata"]
+
+
+def test_get_model_evaluation_slice_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation_slice), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_evaluation_slice.ModelEvaluationSlice()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.get_model_evaluation_slice(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+def test_get_model_evaluation_slice_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.get_model_evaluation_slice(
+            model_service.GetModelEvaluationSliceRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_slice_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.get_model_evaluation_slice), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_evaluation_slice.ModelEvaluationSlice()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.get_model_evaluation_slice(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_get_model_evaluation_slice_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.get_model_evaluation_slice(
+            model_service.GetModelEvaluationSliceRequest(), name="name_value",
+        )
+
+
+def test_list_model_evaluation_slices(
+    transport: str = "grpc", request_type=model_service.ListModelEvaluationSlicesRequest
+):
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelEvaluationSlicesResponse(
+            next_page_token="next_page_token_value",
+        )
+
+        response = client.list_model_evaluation_slices(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.ListModelEvaluationSlicesRequest()
+
+    # Establish that the response is the type that we expect.
+
+    assert isinstance(response, pagers.ListModelEvaluationSlicesPager)
+
+    assert response.next_page_token == "next_page_token_value"
+
+
+def test_list_model_evaluation_slices_from_dict():
+    test_list_model_evaluation_slices(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_async(
+    transport: str = "grpc_asyncio",
+    request_type=model_service.ListModelEvaluationSlicesRequest,
+):
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelEvaluationSlicesResponse(
+                next_page_token="next_page_token_value",
+            )
+        )
+
+        response = await client.list_model_evaluation_slices(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == model_service.ListModelEvaluationSlicesRequest()
+
+    # Establish that the response is the type that we expect.
+    assert isinstance(response, pagers.ListModelEvaluationSlicesAsyncPager)
+
+    assert response.next_page_token == "next_page_token_value"
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_async_from_dict():
+    await test_list_model_evaluation_slices_async(request_type=dict)
+
+
+def test_list_model_evaluation_slices_field_headers():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.ListModelEvaluationSlicesRequest()
+    request.parent = "parent/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        call.return_value = model_service.ListModelEvaluationSlicesResponse()
+
+        client.list_model_evaluation_slices(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_field_headers_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Any value that is part of the HTTP/1.1 URI should be sent as
+    # a field header. Set these to a non-empty value.
+    request = model_service.ListModelEvaluationSlicesRequest()
+    request.parent = "parent/value"
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelEvaluationSlicesResponse()
+        )
+
+        await client.list_model_evaluation_slices(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+        assert args[0] == request
+
+    # Establish that the field header was sent.
+    _, _, kw = call.mock_calls[0]
+    assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"]
+
+
+def test_list_model_evaluation_slices_flattened():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = model_service.ListModelEvaluationSlicesResponse()
+
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        client.list_model_evaluation_slices(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+def test_list_model_evaluation_slices_flattened_error():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        client.list_model_evaluation_slices(
+            model_service.ListModelEvaluationSlicesRequest(), parent="parent_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_flattened_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            model_service.ListModelEvaluationSlicesResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_model_evaluation_slices(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_flattened_error_async():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_model_evaluation_slices(
+            model_service.ListModelEvaluationSlicesRequest(), parent="parent_value",
+        )
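
The `args[0].parent` style assertions above work because flattened keyword arguments are folded into a request message before the transport is invoked, so the stub always sees a single request object. That is, at the transport boundary:

    client.list_model_evaluation_slices(parent="parent_value")
    # is equivalent to:
    client.list_model_evaluation_slices(
        request=model_service.ListModelEvaluationSlicesRequest(parent="parent_value")
    )
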
+
+
+def test_list_model_evaluation_slices_pager():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_model_evaluation_slices(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(
+            isinstance(i, model_evaluation_slice.ModelEvaluationSlice) for i in results
+        )
+
+
+def test_list_model_evaluation_slices_pages():
+    client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_model_evaluation_slices(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_async_pager():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_model_evaluation_slices(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(
+            isinstance(i, model_evaluation_slice.ModelEvaluationSlice)
+            for i in responses
+        )
+
+
+@pytest.mark.asyncio
+async def test_list_model_evaluation_slices_async_pages():
+    client = ModelServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_model_evaluation_slices),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="abc",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[], next_page_token="def",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+                next_page_token="ghi",
+            ),
+            model_service.ListModelEvaluationSlicesResponse(
+                model_evaluation_slices=[
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                    model_evaluation_slice.ModelEvaluationSlice(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = []
+        async for page_ in (
+            await client.list_model_evaluation_slices(request={})
+        ).pages:
+            pages.append(page_)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
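
`test_credentials_transport_error` below pins down a construction-time rule: a ready-made transport already carries its own credentials and scopes, so passing `credentials`, a `credentials_file`, or `scopes` alongside `transport=` raises `ValueError`. Only one of the two styles is valid at a time:

    # Let the client assemble its own transport from credentials/options...
    client = ModelServiceClient(credentials=credentials.AnonymousCredentials())

    # ...or hand it a fully configured transport, and nothing else.
    transport = transports.ModelServiceGrpcTransport(
        credentials=credentials.AnonymousCredentials(),
    )
    client = ModelServiceClient(transport=transport)
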
+ transport = transports.ModelServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = ModelServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.ModelServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = ModelServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.ModelServiceGrpcTransport,) + + +def test_model_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.ModelServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_model_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.ModelServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. 
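+    # The base transport declares every RPC but leaves it unimplemented; only
+    # the concrete transports (gRPC, gRPC asyncio) wire each name to a real
+    # stub. Invoking any of them on the base class should therefore fail
+    # loudly rather than silently no-op.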
+ methods = ( + "upload_model", + "get_model", + "list_models", + "update_model", + "delete_model", + "export_model", + "get_model_evaluation", + "list_model_evaluations", + "get_model_evaluation_slice", + "list_model_evaluation_slices", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_model_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.ModelServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_model_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.ModelServiceTransport() + adc.assert_called_once() + + +def test_model_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + ModelServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_model_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.ModelServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. 
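+    # Per the transport contract exercised here, a ready-made
+    # ssl_channel_credentials object takes precedence and is forwarded to
+    # create_channel untouched; client_cert_source_for_mtls is only consulted
+    # when no such object is supplied (second case below).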
+    with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+        mock_ssl_channel_creds = mock.Mock()
+        transport_class(
+            host="squid.clam.whelk",
+            credentials=cred,
+            ssl_channel_credentials=mock_ssl_channel_creds,
+        )
+        mock_create_channel.assert_called_once_with(
+            "squid.clam.whelk:443",
+            credentials=cred,
+            credentials_file=None,
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            ssl_credentials=mock_ssl_channel_creds,
+            quota_project_id=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+    # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+    # is used.
+    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+            transport_class(
+                credentials=cred,
+                client_cert_source_for_mtls=client_cert_source_callback,
+            )
+            expected_cert, expected_key = client_cert_source_callback()
+            mock_ssl_cred.assert_called_once_with(
+                certificate_chain=expected_cert, private_key=expected_key
+            )
+
+
+def test_model_service_host_no_port():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
+def test_model_service_host_with_port():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(),
+        client_options=client_options.ClientOptions(
+            api_endpoint="aiplatform.googleapis.com:8000"
+        ),
+    )
+    assert client.transport._host == "aiplatform.googleapis.com:8000"
+
+
+def test_model_service_grpc_transport_channel():
+    channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.ModelServiceGrpcTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+def test_model_service_grpc_asyncio_transport_channel():
+    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())
+
+    # Check that channel is used if provided.
+    transport = transports.ModelServiceGrpcAsyncIOTransport(
+        host="squid.clam.whelk", channel=channel,
+    )
+    assert transport.grpc_channel == channel
+    assert transport._host == "squid.clam.whelk:443"
+    assert transport._ssl_channel_credentials is None
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport],
+)
+def test_model_service_transport_channel_mtls_with_client_cert_source(transport_class):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel"
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport],
+)
+def test_model_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(
+            transport_class, "create_channel"
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_model_service_grpc_lro_client():
+    client = ModelServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
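+    # The property is expected to cache the wrapped operations client on first
+    # access, which is what the identity (`is`) check below verifies.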
+    assert transport.operations_client is transport.operations_client
+
+
+def test_model_service_grpc_lro_async_client():
+    client = ModelServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
+
+    # Ensure that subsequent calls to the property send the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_endpoint_path():
+    project = "squid"
+    location = "clam"
+    endpoint = "whelk"
+
+    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
+        project=project, location=location, endpoint=endpoint,
+    )
+    actual = ModelServiceClient.endpoint_path(project, location, endpoint)
+    assert expected == actual
+
+
+def test_parse_endpoint_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "endpoint": "nudibranch",
+    }
+    path = ModelServiceClient.endpoint_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_endpoint_path(path)
+    assert expected == actual
+
+
+def test_model_path():
+    project = "cuttlefish"
+    location = "mussel"
+    model = "winkle"
+
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
+    actual = ModelServiceClient.model_path(project, location, model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "project": "nautilus",
+        "location": "scallop",
+        "model": "abalone",
+    }
+    path = ModelServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_model_path(path)
+    assert expected == actual
+
+
+def test_model_evaluation_path():
+    project = "squid"
+    location = "clam"
+    model = "whelk"
+    evaluation = "octopus"
+
+    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}".format(
+        project=project, location=location, model=model, evaluation=evaluation,
+    )
+    actual = ModelServiceClient.model_evaluation_path(
+        project, location, model, evaluation
+    )
+    assert expected == actual
+
+
+def test_parse_model_evaluation_path():
+    expected = {
+        "project": "oyster",
+        "location": "nudibranch",
+        "model": "cuttlefish",
+        "evaluation": "mussel",
+    }
+    path = ModelServiceClient.model_evaluation_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = ModelServiceClient.parse_model_evaluation_path(path)
+    assert expected == actual
+
+
+def test_model_evaluation_slice_path():
+    project = "winkle"
+    location = "nautilus"
+    model = "scallop"
+    evaluation = "abalone"
+    slice = "squid"
+
+    expected = "projects/{project}/locations/{location}/models/{model}/evaluations/{evaluation}/slices/{slice}".format(
+        project=project,
+        location=location,
+        model=model,
+        evaluation=evaluation,
+        slice=slice,
+    )
+    actual = ModelServiceClient.model_evaluation_slice_path(
+        project, location, model, evaluation, slice
+    )
+    assert expected == actual
+
+
+def test_parse_model_evaluation_slice_path():
+    expected = {
+        "project": "clam",
+        "location": "whelk",
+        "model": "octopus",
+        "evaluation": "oyster",
+        "slice": "nudibranch",
+    }
+    path = ModelServiceClient.model_evaluation_slice_path(**expected)
+
+    # Check that the path construction is reversible.
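+    # The parse_* helpers invert the formatted template, e.g. the path built
+    # above, "projects/clam/locations/whelk/models/octopus/evaluations/oyster/slices/nudibranch",
+    # should decompose back into the keyword arguments it was built from.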
+ actual = ModelServiceClient.parse_model_evaluation_slice_path(path) + assert expected == actual + + +def test_training_pipeline_path(): + project = "cuttlefish" + location = "mussel" + training_pipeline = "winkle" + + expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format( + project=project, location=location, training_pipeline=training_pipeline, + ) + actual = ModelServiceClient.training_pipeline_path( + project, location, training_pipeline + ) + assert expected == actual + + +def test_parse_training_pipeline_path(): + expected = { + "project": "nautilus", + "location": "scallop", + "training_pipeline": "abalone", + } + path = ModelServiceClient.training_pipeline_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_training_pipeline_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "squid" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = ModelServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + } + path = ModelServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "whelk" + + expected = "folders/{folder}".format(folder=folder,) + actual = ModelServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + } + path = ModelServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "oyster" + + expected = "organizations/{organization}".format(organization=organization,) + actual = ModelServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + } + path = ModelServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "cuttlefish" + + expected = "projects/{project}".format(project=project,) + actual = ModelServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + } + path = ModelServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = ModelServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + } + path = ModelServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. 
+    actual = ModelServiceClient.parse_common_location_path(path)
+    assert expected == actual
+
+
+def test_client_withDEFAULT_CLIENT_INFO():
+    client_info = gapic_v1.client_info.ClientInfo()
+
+    with mock.patch.object(
+        transports.ModelServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
+        client = ModelServiceClient(
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
+
+    with mock.patch.object(
+        transports.ModelServiceTransport, "_prep_wrapped_messages"
+    ) as prep:
+        transport_class = ModelServiceClient.get_transport_class()
+        transport = transport_class(
+            credentials=credentials.AnonymousCredentials(), client_info=client_info,
+        )
+        prep.assert_called_once_with(client_info)
diff --git a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py
new file mode 100644
index 0000000000..23619209b0
--- /dev/null
+++ b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py
@@ -0,0 +1,2288 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import mock
+
+import grpc
+from grpc.experimental import aio
+import math
+import pytest
+from proto.marshal.rules.dates import DurationRule, TimestampRule
+
+from google import auth
+from google.api_core import client_options
+from google.api_core import exceptions
+from google.api_core import future
+from google.api_core import gapic_v1
+from google.api_core import grpc_helpers
+from google.api_core import grpc_helpers_async
+from google.api_core import operation_async  # type: ignore
+from google.api_core import operations_v1
+from google.auth import credentials
+from google.auth.exceptions import MutualTLSChannelError
+from google.cloud.aiplatform_v1.services.pipeline_service import (
+    PipelineServiceAsyncClient,
+)
+from google.cloud.aiplatform_v1.services.pipeline_service import PipelineServiceClient
+from google.cloud.aiplatform_v1.services.pipeline_service import pagers
+from google.cloud.aiplatform_v1.services.pipeline_service import transports
+from google.cloud.aiplatform_v1.types import deployed_model_ref
+from google.cloud.aiplatform_v1.types import encryption_spec
+from google.cloud.aiplatform_v1.types import env_var
+from google.cloud.aiplatform_v1.types import io
+from google.cloud.aiplatform_v1.types import model
+from google.cloud.aiplatform_v1.types import operation as gca_operation
+from google.cloud.aiplatform_v1.types import pipeline_service
+from google.cloud.aiplatform_v1.types import pipeline_state
+from google.cloud.aiplatform_v1.types import training_pipeline
+from google.cloud.aiplatform_v1.types import training_pipeline as gca_training_pipeline
+from google.longrunning import operations_pb2
+from google.oauth2 import service_account
+from google.protobuf import any_pb2 as gp_any  # type: ignore
+from google.protobuf import field_mask_pb2 as field_mask  # type: ignore
+from google.protobuf import struct_pb2 as struct  # type: ignore
+from google.protobuf import timestamp_pb2 as timestamp # type: ignore +from google.rpc import status_pb2 as status # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert PipelineServiceClient._get_default_mtls_endpoint(None) is None + assert ( + PipelineServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PipelineServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) + + +def test_pipeline_service_client_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = PipelineServiceClient.from_service_account_info(info) + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +@pytest.mark.parametrize( + "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,] +) +def test_pipeline_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_pipeline_service_client_get_transport_class(): + transport = PipelineServiceClient.get_transport_class() + available_transports = [ + transports.PipelineServiceGrpcTransport, + ] + assert transport in available_transports + + transport = PipelineServiceClient.get_transport_class("grpc") + assert transport == transports.PipelineServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + PipelineServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceClient), +) +@mock.patch.object( + PipelineServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceAsyncClient), +) +def test_pipeline_service_client_client_options( + 
+    client_class, transport_class, transport_name
+):
+    # Check that if channel is provided we won't create a new one.
+    with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc:
+        transport = transport_class(credentials=credentials.AnonymousCredentials())
+        client = client_class(transport=transport)
+        gtc.assert_not_called()
+
+    # Check that if channel is provided via str we will create a new one.
+    with mock.patch.object(PipelineServiceClient, "get_transport_class") as gtc:
+        client = client_class(transport=transport_name)
+        gtc.assert_called()
+
+    # Check the case api_endpoint is provided.
+    options = client_options.ClientOptions(api_endpoint="squid.clam.whelk")
+    with mock.patch.object(transport_class, "__init__") as patched:
+        patched.return_value = None
+        client = client_class(client_options=options)
+        patched.assert_called_once_with(
+            credentials=None,
+            credentials_file=None,
+            host="squid.clam.whelk",
+            scopes=None,
+            client_cert_source_for_mtls=None,
+            quota_project_id=None,
+            client_info=transports.base.DEFAULT_CLIENT_INFO,
+        )
+
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
+    # "never".
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            patched.return_value = None
+            client = client_class()
+            patched.assert_called_once_with(
+                credentials=None,
+                credentials_file=None,
+                host=client.DEFAULT_ENDPOINT,
+                scopes=None,
+                client_cert_source_for_mtls=None,
+                quota_project_id=None,
+                client_info=transports.base.DEFAULT_CLIENT_INFO,
+            )
+
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is
+    # "always".
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            patched.return_value = None
+            client = client_class()
+            patched.assert_called_once_with(
+                credentials=None,
+                credentials_file=None,
+                host=client.DEFAULT_MTLS_ENDPOINT,
+                scopes=None,
+                client_cert_source_for_mtls=None,
+                quota_project_id=None,
+                client_info=transports.base.DEFAULT_CLIENT_INFO,
+            )
+
+    # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has
+    # unsupported value.
+    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}):
+        with pytest.raises(MutualTLSChannelError):
+            client = client_class()
+
+    # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value.
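+    # GOOGLE_API_USE_CLIENT_CERTIFICATE only accepts "true" or "false";
+    # anything else should surface as a ValueError at client construction, as
+    # asserted below.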
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + PipelineServiceClient, + transports.PipelineServiceGrpcTransport, + "grpc", + "true", + ), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + PipelineServiceClient, + transports.PipelineServiceGrpcTransport, + "grpc", + "false", + ), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + PipelineServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceClient), +) +@mock.patch.object( + PipelineServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PipelineServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_pipeline_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
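+    # Here no explicit client_cert_source is supplied; the client is expected
+    # to fall back to the ADC default cert source discovered via
+    # google.auth.transport.mtls, which the nested patches below simulate.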
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_pipeline_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PipelineServiceClient, transports.PipelineServiceGrpcTransport, "grpc"), + ( + PipelineServiceAsyncClient, + transports.PipelineServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_pipeline_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
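+    # The credentials file itself is never read in this test; the client
+    # should simply forward the credentials_file path to the (patched)
+    # transport, which would load it later via google.auth.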
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_pipeline_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = PipelineServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.CreateTrainingPipelineRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + + response = client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CreateTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, gca_training_pipeline.TrainingPipeline) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.training_task_definition == "training_task_definition_value" + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +def test_create_training_pipeline_from_dict(): + test_create_training_pipeline(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.CreateTrainingPipelineRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
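+        # For the async client the mocked stub must return an awaitable;
+        # grpc_helpers_async.FakeUnaryUnaryCall wraps the response so that
+        # awaiting client.create_training_pipeline(request) resolves to it.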
+ call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + ) + + response = await client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CreateTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, gca_training_pipeline.TrainingPipeline) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.training_task_definition == "training_task_definition_value" + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_create_training_pipeline_async_from_dict(): + await test_create_training_pipeline_async(request_type=dict) + + +def test_create_training_pipeline_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CreateTrainingPipelineRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + call.return_value = gca_training_pipeline.TrainingPipeline() + + client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_training_pipeline_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CreateTrainingPipelineRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline() + ) + + await client.create_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_training_pipeline_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
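+        # Flattened keyword arguments are a convenience surface: the client is
+        # expected to fold them into a CreateTrainingPipelineRequest before
+        # invoking the transport, which the assertions below confirm.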
+ call.return_value = gca_training_pipeline.TrainingPipeline() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_training_pipeline( + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( + name="name_value" + ) + + +def test_create_training_pipeline_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_training_pipeline( + pipeline_service.CreateTrainingPipelineRequest(), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_training_pipeline_flattened_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gca_training_pipeline.TrainingPipeline() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gca_training_pipeline.TrainingPipeline() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_training_pipeline( + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].training_pipeline == gca_training_pipeline.TrainingPipeline( + name="name_value" + ) + + +@pytest.mark.asyncio +async def test_create_training_pipeline_flattened_error_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_training_pipeline( + pipeline_service.CreateTrainingPipelineRequest(), + parent="parent_value", + training_pipeline=gca_training_pipeline.TrainingPipeline(name="name_value"), + ) + + +def test_get_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.GetTrainingPipelineRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + + response = client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.GetTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, training_pipeline.TrainingPipeline) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.training_task_definition == "training_task_definition_value" + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +def test_get_training_pipeline_from_dict(): + test_get_training_pipeline(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.GetTrainingPipelineRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline( + name="name_value", + display_name="display_name_value", + training_task_definition="training_task_definition_value", + state=pipeline_state.PipelineState.PIPELINE_STATE_QUEUED, + ) + ) + + response = await client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.GetTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, training_pipeline.TrainingPipeline) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.training_task_definition == "training_task_definition_value" + + assert response.state == pipeline_state.PipelineState.PIPELINE_STATE_QUEUED + + +@pytest.mark.asyncio +async def test_get_training_pipeline_async_from_dict(): + await test_get_training_pipeline_async(request_type=dict) + + +def test_get_training_pipeline_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.GetTrainingPipelineRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_training_pipeline), "__call__" + ) as call: + call.return_value = training_pipeline.TrainingPipeline() + + client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. 
+ assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_training_pipeline_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.GetTrainingPipelineRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline() + ) + + await client.get_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_training_pipeline_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = training_pipeline.TrainingPipeline() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_training_pipeline(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_training_pipeline_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_training_pipeline( + pipeline_service.GetTrainingPipelineRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_training_pipeline_flattened_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = training_pipeline.TrainingPipeline() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + training_pipeline.TrainingPipeline() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_training_pipeline(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_training_pipeline_flattened_error_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_training_pipeline( + pipeline_service.GetTrainingPipelineRequest(), name="name_value", + ) + + +def test_list_training_pipelines( + transport: str = "grpc", request_type=pipeline_service.ListTrainingPipelinesRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_service.ListTrainingPipelinesResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.ListTrainingPipelinesRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListTrainingPipelinesPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_training_pipelines_from_dict(): + test_list_training_pipelines(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_training_pipelines_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.ListTrainingPipelinesRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.ListTrainingPipelinesRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTrainingPipelinesAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_training_pipelines_async_from_dict(): + await test_list_training_pipelines_async(request_type=dict) + + +def test_list_training_pipelines_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
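+    # Routing parameters travel in the "x-goog-request-params" metadata entry,
+    # e.g. ("x-goog-request-params", "parent=parent/value"), so the server can
+    # route the request without parsing the request body.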
+ request = pipeline_service.ListTrainingPipelinesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + call.return_value = pipeline_service.ListTrainingPipelinesResponse() + + client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_training_pipelines_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.ListTrainingPipelinesRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + pipeline_service.ListTrainingPipelinesResponse() + ) + + await client.list_training_pipelines(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_training_pipelines_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = pipeline_service.ListTrainingPipelinesResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_training_pipelines(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_training_pipelines_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_training_pipelines( + pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_training_pipelines_flattened_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_training_pipelines), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            pipeline_service.ListTrainingPipelinesResponse()
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.list_training_pipelines(parent="parent_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].parent == "parent_value"
+
+
+@pytest.mark.asyncio
+async def test_list_training_pipelines_flattened_error_async():
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.list_training_pipelines(
+            pipeline_service.ListTrainingPipelinesRequest(), parent="parent_value",
+        )
+
+
+def test_list_training_pipelines_pager():
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_training_pipelines), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token="abc",
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[], next_page_token="def",
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[training_pipeline.TrainingPipeline(),],
+                next_page_token="ghi",
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+            ),
+            RuntimeError,
+        )
+
+        metadata = ()
+        metadata = tuple(metadata) + (
+            gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)),
+        )
+        pager = client.list_training_pipelines(request={})
+
+        assert pager._metadata == metadata
+
+        results = [i for i in pager]
+        assert len(results) == 6
+        assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in results)
+
+
+def test_list_training_pipelines_pages():
+    client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_training_pipelines), "__call__"
+    ) as call:
+        # Set the response to a series of pages.
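+        # The trailing RuntimeError is a sentinel: iteration must stop at the
+        # page with an empty next_page_token and never reach it.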
+        call.side_effect = (
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token="abc",
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[], next_page_token="def",
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[training_pipeline.TrainingPipeline(),],
+                next_page_token="ghi",
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+            ),
+            RuntimeError,
+        )
+        pages = list(client.list_training_pipelines(request={}).pages)
+        for page_, token in zip(pages, ["abc", "def", "ghi", ""]):
+            assert page_.raw_page.next_page_token == token
+
+
+@pytest.mark.asyncio
+async def test_list_training_pipelines_async_pager():
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_training_pipelines),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+        call.side_effect = (
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+                next_page_token="abc",
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[], next_page_token="def",
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[training_pipeline.TrainingPipeline(),],
+                next_page_token="ghi",
+            ),
+            pipeline_service.ListTrainingPipelinesResponse(
+                training_pipelines=[
+                    training_pipeline.TrainingPipeline(),
+                    training_pipeline.TrainingPipeline(),
+                ],
+            ),
+            RuntimeError,
+        )
+        async_pager = await client.list_training_pipelines(request={},)
+        assert async_pager.next_page_token == "abc"
+        responses = []
+        async for response in async_pager:
+            responses.append(response)
+
+        assert len(responses) == 6
+        assert all(isinstance(i, training_pipeline.TrainingPipeline) for i in responses)
+
+
+@pytest.mark.asyncio
+async def test_list_training_pipelines_async_pages():
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.list_training_pipelines),
+        "__call__",
+        new_callable=mock.AsyncMock,
+    ) as call:
+        # Set the response to a series of pages.
+ call.side_effect = ( + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[ + training_pipeline.TrainingPipeline(), + training_pipeline.TrainingPipeline(), + training_pipeline.TrainingPipeline(), + ], + next_page_token="abc", + ), + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[], next_page_token="def", + ), + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[training_pipeline.TrainingPipeline(),], + next_page_token="ghi", + ), + pipeline_service.ListTrainingPipelinesResponse( + training_pipelines=[ + training_pipeline.TrainingPipeline(), + training_pipeline.TrainingPipeline(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_training_pipelines(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_training_pipeline( + transport: str = "grpc", request_type=pipeline_service.DeleteTrainingPipelineRequest +): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_delete_training_pipeline_from_dict(): + test_delete_training_pipeline(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_training_pipeline_async( + transport: str = "grpc_asyncio", + request_type=pipeline_service.DeleteTrainingPipelineRequest, +): + client = PipelineServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.DeleteTrainingPipelineRequest() + + # Establish that the response is the type that we expect. 
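+    # DeleteTrainingPipeline is a long-running operation, so the client
+    # returns an operation future rather than the raw Operation proto.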
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_training_pipeline_async_from_dict(): + await test_delete_training_pipeline_async(request_type=dict) + + +def test_delete_training_pipeline_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.DeleteTrainingPipelineRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_training_pipeline_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.DeleteTrainingPipelineRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_training_pipeline_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_training_pipeline(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_training_pipeline_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.delete_training_pipeline(
+            pipeline_service.DeleteTrainingPipelineRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_delete_training_pipeline_flattened_async():
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.delete_training_pipeline), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(
+            operations_pb2.Operation(name="operations/spam")
+        )
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.delete_training_pipeline(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_delete_training_pipeline_flattened_error_async():
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.delete_training_pipeline(
+            pipeline_service.DeleteTrainingPipelineRequest(), name="name_value",
+        )
+
+
+def test_cancel_training_pipeline(
+    transport: str = "grpc", request_type=pipeline_service.CancelTrainingPipelineRequest
+):
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.cancel_training_pipeline), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = None
+
+        response = client.cancel_training_pipeline(request)
+
+        # Establish that the underlying gRPC stub method was called.
+        assert len(call.mock_calls) == 1
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0] == pipeline_service.CancelTrainingPipelineRequest()
+
+    # Establish that the response is the type that we expect.
+    assert response is None
+
+
+def test_cancel_training_pipeline_from_dict():
+    test_cancel_training_pipeline(request_type=dict)
+
+
+@pytest.mark.asyncio
+async def test_cancel_training_pipeline_async(
+    transport: str = "grpc_asyncio",
+    request_type=pipeline_service.CancelTrainingPipelineRequest,
+):
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport=transport,
+    )
+
+    # Everything is optional in proto3 as far as the runtime is concerned,
+    # and we are mocking out the actual API, so just send an empty request.
+    request = request_type()
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.cancel_training_pipeline), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+
+        response = await client.cancel_training_pipeline(request)
+
+        # Establish that the underlying gRPC stub method was called.
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == pipeline_service.CancelTrainingPipelineRequest() + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_cancel_training_pipeline_async_from_dict(): + await test_cancel_training_pipeline_async(request_type=dict) + + +def test_cancel_training_pipeline_field_headers(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CancelTrainingPipelineRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: + call.return_value = None + + client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_cancel_training_pipeline_field_headers_async(): + client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = pipeline_service.CancelTrainingPipelineRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + + await client.cancel_training_pipeline(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_cancel_training_pipeline_flattened(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.cancel_training_pipeline), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.cancel_training_pipeline(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_cancel_training_pipeline_flattened_error(): + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+    with pytest.raises(ValueError):
+        client.cancel_training_pipeline(
+            pipeline_service.CancelTrainingPipelineRequest(), name="name_value",
+        )
+
+
+@pytest.mark.asyncio
+async def test_cancel_training_pipeline_flattened_async():
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Mock the actual call within the gRPC stub, and fake the request.
+    with mock.patch.object(
+        type(client.transport.cancel_training_pipeline), "__call__"
+    ) as call:
+        # Designate an appropriate return value for the call.
+        call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None)
+        # Call the method with a truthy value for each flattened field,
+        # using the keyword arguments to the method.
+        response = await client.cancel_training_pipeline(name="name_value",)
+
+        # Establish that the underlying call was made with the expected
+        # request object values.
+        assert len(call.mock_calls)
+        _, args, _ = call.mock_calls[0]
+
+        assert args[0].name == "name_value"
+
+
+@pytest.mark.asyncio
+async def test_cancel_training_pipeline_flattened_error_async():
+    client = PipelineServiceAsyncClient(credentials=credentials.AnonymousCredentials(),)
+
+    # Attempting to call a method with both a request object and flattened
+    # fields is an error.
+    with pytest.raises(ValueError):
+        await client.cancel_training_pipeline(
+            pipeline_service.CancelTrainingPipelineRequest(), name="name_value",
+        )
+
+
+def test_credentials_transport_error():
+    # It is an error to provide credentials and a transport instance.
+    transport = transports.PipelineServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = PipelineServiceClient(
+            credentials=credentials.AnonymousCredentials(), transport=transport,
+        )
+
+    # It is an error to provide a credentials file and a transport instance.
+    transport = transports.PipelineServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = PipelineServiceClient(
+            client_options={"credentials_file": "credentials.json"},
+            transport=transport,
+        )
+
+    # It is an error to provide scopes and a transport instance.
+    transport = transports.PipelineServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    with pytest.raises(ValueError):
+        client = PipelineServiceClient(
+            client_options={"scopes": ["1", "2"]}, transport=transport,
+        )
+
+
+def test_transport_instance():
+    # A client may be instantiated with a custom transport instance.
+    transport = transports.PipelineServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    client = PipelineServiceClient(transport=transport)
+    assert client.transport is transport
+
+
+def test_transport_get_channel():
+    # A client may be instantiated with a custom transport instance.
+    transport = transports.PipelineServiceGrpcTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    channel = transport.grpc_channel
+    assert channel
+
+    transport = transports.PipelineServiceGrpcAsyncIOTransport(
+        credentials=credentials.AnonymousCredentials(),
+    )
+    channel = transport.grpc_channel
+    assert channel
+
+
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.PipelineServiceGrpcTransport,
+        transports.PipelineServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_transport_adc(transport_class):
+    # Test default credentials are used if not provided.
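+    # Patching google.auth.default keeps the test from picking up real
+    # Application Default Credentials from the environment.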
+ with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = PipelineServiceClient(credentials=credentials.AnonymousCredentials(),) + assert isinstance(client.transport, transports.PipelineServiceGrpcTransport,) + + +def test_pipeline_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.PipelineServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_pipeline_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.PipelineServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_training_pipeline", + "get_training_pipeline", + "list_training_pipelines", + "delete_training_pipeline", + "cancel_training_pipeline", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_pipeline_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.PipelineServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_pipeline_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.pipeline_service.transports.PipelineServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.PipelineServiceTransport() + adc.assert_called_once() + + +def test_pipeline_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + PipelineServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_pipeline_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
+ with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.PipelineServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.PipelineServiceGrpcTransport, + transports.PipelineServiceGrpcAsyncIOTransport, + ], +) +def test_pipeline_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_pipeline_service_host_no_port(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_pipeline_service_host_with_port(): + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:8000" + + +def test_pipeline_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.PipelineServiceGrpcTransport( + host="squid.clam.whelk", channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_pipeline_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.PipelineServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. 
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.PipelineServiceGrpcTransport,
+        transports.PipelineServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_pipeline_service_transport_channel_mtls_with_client_cert_source(
+    transport_class,
+):
+    with mock.patch(
+        "grpc.ssl_channel_credentials", autospec=True
+    ) as grpc_ssl_channel_cred:
+        with mock.patch.object(
+            transport_class, "create_channel"
+        ) as grpc_create_channel:
+            mock_ssl_cred = mock.Mock()
+            grpc_ssl_channel_cred.return_value = mock_ssl_cred
+
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+
+            cred = credentials.AnonymousCredentials()
+            with pytest.warns(DeprecationWarning):
+                with mock.patch.object(auth, "default") as adc:
+                    adc.return_value = (cred, None)
+                    transport = transport_class(
+                        host="squid.clam.whelk",
+                        api_mtls_endpoint="mtls.squid.clam.whelk",
+                        client_cert_source=client_cert_source_callback,
+                    )
+                    adc.assert_called_once()
+
+            grpc_ssl_channel_cred.assert_called_once_with(
+                certificate_chain=b"cert bytes", private_key=b"key bytes"
+            )
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+            assert transport._ssl_channel_credentials == mock_ssl_cred
+
+
+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.PipelineServiceGrpcTransport,
+        transports.PipelineServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_pipeline_service_transport_channel_mtls_with_adc(transport_class):
+    mock_ssl_cred = mock.Mock()
+    with mock.patch.multiple(
+        "google.auth.transport.grpc.SslCredentials",
+        __init__=mock.Mock(return_value=None),
+        ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
+    ):
+        with mock.patch.object(
+            transport_class, "create_channel"
+        ) as grpc_create_channel:
+            mock_grpc_channel = mock.Mock()
+            grpc_create_channel.return_value = mock_grpc_channel
+            mock_cred = mock.Mock()
+
+            with pytest.warns(DeprecationWarning):
+                transport = transport_class(
+                    host="squid.clam.whelk",
+                    credentials=mock_cred,
+                    api_mtls_endpoint="mtls.squid.clam.whelk",
+                    client_cert_source=None,
+                )
+
+            grpc_create_channel.assert_called_once_with(
+                "mtls.squid.clam.whelk:443",
+                credentials=mock_cred,
+                credentials_file=None,
+                scopes=("https://www.googleapis.com/auth/cloud-platform",),
+                ssl_credentials=mock_ssl_cred,
+                quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
+            )
+            assert transport.grpc_channel == mock_grpc_channel
+
+
+def test_pipeline_service_grpc_lro_client():
+    client = PipelineServiceClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsClient,)
+
+    # Ensure that subsequent calls to the property return the exact same object.
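+    # The identity check below proves the property caches its client instead
+    # of constructing a new one on each access.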
+    assert transport.operations_client is transport.operations_client
+
+
+def test_pipeline_service_grpc_lro_async_client():
+    client = PipelineServiceAsyncClient(
+        credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio",
+    )
+    transport = client.transport
+
+    # Ensure that we have an api-core operations client.
+    assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,)
+
+    # Ensure that subsequent calls to the property return the exact same object.
+    assert transport.operations_client is transport.operations_client
+
+
+def test_endpoint_path():
+    project = "squid"
+    location = "clam"
+    endpoint = "whelk"
+
+    expected = "projects/{project}/locations/{location}/endpoints/{endpoint}".format(
+        project=project, location=location, endpoint=endpoint,
+    )
+    actual = PipelineServiceClient.endpoint_path(project, location, endpoint)
+    assert expected == actual
+
+
+def test_parse_endpoint_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "endpoint": "nudibranch",
+    }
+    path = PipelineServiceClient.endpoint_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = PipelineServiceClient.parse_endpoint_path(path)
+    assert expected == actual
+
+
+def test_model_path():
+    project = "cuttlefish"
+    location = "mussel"
+    model = "winkle"
+
+    expected = "projects/{project}/locations/{location}/models/{model}".format(
+        project=project, location=location, model=model,
+    )
+    actual = PipelineServiceClient.model_path(project, location, model)
+    assert expected == actual
+
+
+def test_parse_model_path():
+    expected = {
+        "project": "nautilus",
+        "location": "scallop",
+        "model": "abalone",
+    }
+    path = PipelineServiceClient.model_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = PipelineServiceClient.parse_model_path(path)
+    assert expected == actual
+
+
+def test_training_pipeline_path():
+    project = "squid"
+    location = "clam"
+    training_pipeline = "whelk"
+
+    expected = "projects/{project}/locations/{location}/trainingPipelines/{training_pipeline}".format(
+        project=project, location=location, training_pipeline=training_pipeline,
+    )
+    actual = PipelineServiceClient.training_pipeline_path(
+        project, location, training_pipeline
+    )
+    assert expected == actual
+
+
+def test_parse_training_pipeline_path():
+    expected = {
+        "project": "octopus",
+        "location": "oyster",
+        "training_pipeline": "nudibranch",
+    }
+    path = PipelineServiceClient.training_pipeline_path(**expected)
+
+    # Check that the path construction is reversible.
+    actual = PipelineServiceClient.parse_training_pipeline_path(path)
+    assert expected == actual
+
+
+def test_common_billing_account_path():
+    billing_account = "cuttlefish"
+
+    expected = "billingAccounts/{billing_account}".format(
+        billing_account=billing_account,
+    )
+    actual = PipelineServiceClient.common_billing_account_path(billing_account)
+    assert expected == actual
+
+
+def test_parse_common_billing_account_path():
+    expected = {
+        "billing_account": "mussel",
+    }
+    path = PipelineServiceClient.common_billing_account_path(**expected)
+
+    # Check that the path construction is reversible.
+ actual = PipelineServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "winkle" + + expected = "folders/{folder}".format(folder=folder,) + actual = PipelineServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nautilus", + } + path = PipelineServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "scallop" + + expected = "organizations/{organization}".format(organization=organization,) + actual = PipelineServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "abalone", + } + path = PipelineServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "squid" + + expected = "projects/{project}".format(project=project,) + actual = PipelineServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + } + path = PipelineServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "whelk" + location = "octopus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = PipelineServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + } + path = PipelineServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = PipelineServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.PipelineServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = PipelineServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.PipelineServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = PipelineServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py new file mode 100644 index 0000000000..e2be66e2c7 --- /dev/null +++ b/tests/unit/gapic/aiplatform_v1/test_specialist_pool_service.py @@ -0,0 +1,2325 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import mock + +import grpc +from grpc.experimental import aio +import math +import pytest +from proto.marshal.rules.dates import DurationRule, TimestampRule + +from google import auth +from google.api_core import client_options +from google.api_core import exceptions +from google.api_core import future +from google.api_core import gapic_v1 +from google.api_core import grpc_helpers +from google.api_core import grpc_helpers_async +from google.api_core import operation_async # type: ignore +from google.api_core import operations_v1 +from google.auth import credentials +from google.auth.exceptions import MutualTLSChannelError +from google.cloud.aiplatform_v1.services.specialist_pool_service import ( + SpecialistPoolServiceAsyncClient, +) +from google.cloud.aiplatform_v1.services.specialist_pool_service import ( + SpecialistPoolServiceClient, +) +from google.cloud.aiplatform_v1.services.specialist_pool_service import pagers +from google.cloud.aiplatform_v1.services.specialist_pool_service import transports +from google.cloud.aiplatform_v1.types import operation as gca_operation +from google.cloud.aiplatform_v1.types import specialist_pool +from google.cloud.aiplatform_v1.types import specialist_pool as gca_specialist_pool +from google.cloud.aiplatform_v1.types import specialist_pool_service +from google.longrunning import operations_pb2 +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 as field_mask # type: ignore + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. 
+def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert SpecialistPoolServiceClient._get_default_mtls_endpoint(None) is None + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + SpecialistPoolServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +def test_specialist_pool_service_client_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = SpecialistPoolServiceClient.from_service_account_info(info) + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +@pytest.mark.parametrize( + "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,] +) +def test_specialist_pool_service_client_from_service_account_file(client_class): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file("dummy/file/path.json") + assert client.transport._credentials == creds + + client = client_class.from_service_account_json("dummy/file/path.json") + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_specialist_pool_service_client_get_transport_class(): + transport = SpecialistPoolServiceClient.get_transport_class() + available_transports = [ + transports.SpecialistPoolServiceGrpcTransport, + ] + assert transport in available_transports + + transport = SpecialistPoolServiceClient.get_transport_class("grpc") + assert transport == transports.SpecialistPoolServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) +def test_specialist_pool_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. 
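+    # get_transport_class is mocked only to observe whether the client tries
+    # to construct a transport of its own.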
+ with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(SpecialistPoolServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError): + client = client_class() + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. 
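+    # Only "true" and "false" are accepted; any other value must be rejected
+    # rather than silently ignored.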
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError): + client = client_class() + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + "true", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + "false", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ], +) +@mock.patch.object( + SpecialistPoolServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceClient), +) +@mock.patch.object( + SpecialistPoolServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(SpecialistPoolServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_specialist_pool_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_specialist_pool_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions(scopes=["1", "2"],) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + ( + SpecialistPoolServiceClient, + transports.SpecialistPoolServiceGrpcTransport, + "grpc", + ), + ( + SpecialistPoolServiceAsyncClient, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + ], +) +def test_specialist_pool_service_client_client_options_credentials_file( + client_class, transport_class, transport_name +): + # Check the case credentials file is provided. 
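+    # The path is passed through to the transport as-is; no credentials are
+    # actually read from disk in this test.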
+ options = client_options.ClientOptions(credentials_file="credentials.json") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_specialist_pool_service_client_client_options_from_dict(): + with mock.patch( + "google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = SpecialistPoolServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + + +def test_create_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.CreateSpecialistPoolRequest, +): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_specialist_pool_from_dict(): + test_create_specialist_pool(request_type=dict) + + +@pytest.mark.asyncio +async def test_create_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.CreateSpecialistPoolRequest, +): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.CreateSpecialistPoolRequest() + + # Establish that the response is the type that we expect. 
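+    # CreateSpecialistPool is a long-running operation, so the client returns
+    # an operation future rather than the raw Operation proto.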
+ assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_specialist_pool_async_from_dict(): + await test_create_specialist_pool_async(request_type=dict) + + +def test_create_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.CreateSpecialistPoolRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_specialist_pool_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.CreateSpecialistPoolRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.create_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_create_specialist_pool_flattened(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_specialist_pool( + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) + + +def test_create_specialist_pool_flattened_error(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_specialist_pool( + specialist_pool_service.CreateSpecialistPoolRequest(), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_specialist_pool_flattened_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_specialist_pool( + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) + + +@pytest.mark.asyncio +async def test_create_specialist_pool_flattened_error_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_specialist_pool( + specialist_pool_service.CreateSpecialistPoolRequest(), + parent="parent_value", + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + ) + + +def test_get_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.GetSpecialistPoolRequest, +): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = specialist_pool.SpecialistPool( + name="name_value", + display_name="display_name_value", + specialist_managers_count=2662, + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], + ) + + response = client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() + + # Establish that the response is the type that we expect. 
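+ # Get is a plain unary call: unlike the LRO methods above, the response is + # the SpecialistPool message itself, so each mocked field can be asserted + # on the response directly.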
+ + assert isinstance(response, specialist_pool.SpecialistPool) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.specialist_managers_count == 2662 + + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] + + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] + + +def test_get_specialist_pool_from_dict(): + test_get_specialist_pool(request_type=dict) + + +@pytest.mark.asyncio +async def test_get_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.GetSpecialistPoolRequest, +): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool( + name="name_value", + display_name="display_name_value", + specialist_managers_count=2662, + specialist_manager_emails=["specialist_manager_emails_value"], + pending_data_labeling_jobs=["pending_data_labeling_jobs_value"], + ) + ) + + response = await client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.GetSpecialistPoolRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, specialist_pool.SpecialistPool) + + assert response.name == "name_value" + + assert response.display_name == "display_name_value" + + assert response.specialist_managers_count == 2662 + + assert response.specialist_manager_emails == ["specialist_manager_emails_value"] + + assert response.pending_data_labeling_jobs == ["pending_data_labeling_jobs_value"] + + +@pytest.mark.asyncio +async def test_get_specialist_pool_async_from_dict(): + await test_get_specialist_pool_async(request_type=dict) + + +def test_get_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.GetSpecialistPoolRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), "__call__" + ) as call: + call.return_value = specialist_pool.SpecialistPool() + + client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_specialist_pool_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value.
+ request = specialist_pool_service.GetSpecialistPoolRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool() + ) + + await client.get_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_get_specialist_pool_flattened(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = specialist_pool.SpecialistPool() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_specialist_pool(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_get_specialist_pool_flattened_error(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_specialist_pool( + specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_specialist_pool_flattened_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = specialist_pool.SpecialistPool() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool.SpecialistPool() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_specialist_pool(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_get_specialist_pool_flattened_error_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error.
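+ # The request-vs-flattened-arguments check runs while the request is being + # built, before any RPC is attempted, so no transport mocking is needed for + # this error path.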
+ with pytest.raises(ValueError): + await client.get_specialist_pool( + specialist_pool_service.GetSpecialistPoolRequest(), name="name_value", + ) + + +def test_list_specialist_pools( + transport: str = "grpc", + request_type=specialist_pool_service.ListSpecialistPoolsRequest, +): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token="next_page_token_value", + ) + + response = client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() + + # Establish that the response is the type that we expect. + + assert isinstance(response, pagers.ListSpecialistPoolsPager) + + assert response.next_page_token == "next_page_token_value" + + +def test_list_specialist_pools_from_dict(): + test_list_specialist_pools(request_type=dict) + + +@pytest.mark.asyncio +async def test_list_specialist_pools_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.ListSpecialistPoolsRequest, +): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse( + next_page_token="next_page_token_value", + ) + ) + + response = await client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.ListSpecialistPoolsRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListSpecialistPoolsAsyncPager) + + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_specialist_pools_async_from_dict(): + await test_list_specialist_pools_async(request_type=dict) + + +def test_list_specialist_pools_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.ListSpecialistPoolsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
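+ # Patching __call__ on the type of the bound transport method intercepts + # the underlying gRPC callable, letting the assertions below inspect the + # exact request object and metadata handed to gRPC.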
+ with mock.patch.object( + type(client.transport.list_specialist_pools), "__call__" + ) as call: + call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() + + client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_specialist_pools_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.ListSpecialistPoolsRequest() + request.parent = "parent/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse() + ) + + await client.list_specialist_pools(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "parent=parent/value",) in kw["metadata"] + + +def test_list_specialist_pools_flattened(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_specialist_pools(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +def test_list_specialist_pools_flattened_error(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_specialist_pools( + specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_specialist_pools_flattened_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = specialist_pool_service.ListSpecialistPoolsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + specialist_pool_service.ListSpecialistPoolsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. 
+ response = await client.list_specialist_pools(parent="parent_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].parent == "parent_value" + + +@pytest.mark.asyncio +async def test_list_specialist_pools_flattened_error_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_specialist_pools( + specialist_pool_service.ListSpecialistPoolsRequest(), parent="parent_value", + ) + + +def test_list_specialist_pools_pager(): + client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[ + specialist_pool.SpecialistPool(), + specialist_pool.SpecialistPool(), + specialist_pool.SpecialistPool(), + ], + next_page_token="abc", + ), + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[], next_page_token="def", + ), + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", + ), + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[ + specialist_pool.SpecialistPool(), + specialist_pool.SpecialistPool(), + ], + ), + RuntimeError, + ) + + metadata = () + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_specialist_pools(request={}) + + assert pager._metadata == metadata + + results = [i for i in pager] + assert len(results) == 6 + assert all(isinstance(i, specialist_pool.SpecialistPool) for i in results) + + +def test_list_specialist_pools_pages(): + client = SpecialistPoolServiceClient(credentials=credentials.AnonymousCredentials,) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[ + specialist_pool.SpecialistPool(), + specialist_pool.SpecialistPool(), + specialist_pool.SpecialistPool(), + ], + next_page_token="abc", + ), + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[], next_page_token="def", + ), + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", + ), + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[ + specialist_pool.SpecialistPool(), + specialist_pool.SpecialistPool(), + ], + ), + RuntimeError, + ) + pages = list(client.list_specialist_pools(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_specialist_pools_async_pager(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
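+ # new_callable=mock.AsyncMock matters here: the async pager awaits each + # page fetch, so the queued side_effect responses must come back as + # awaitables rather than plain return values.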
+ with mock.patch.object( + type(client.transport.list_specialist_pools), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[ + specialist_pool.SpecialistPool(), + specialist_pool.SpecialistPool(), + specialist_pool.SpecialistPool(), + ], + next_page_token="abc", + ), + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[], next_page_token="def", + ), + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", + ), + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[ + specialist_pool.SpecialistPool(), + specialist_pool.SpecialistPool(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_specialist_pools(request={},) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, specialist_pool.SpecialistPool) for i in responses) + + +@pytest.mark.asyncio +async def test_list_specialist_pools_async_pages(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_specialist_pools), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[ + specialist_pool.SpecialistPool(), + specialist_pool.SpecialistPool(), + specialist_pool.SpecialistPool(), + ], + next_page_token="abc", + ), + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[], next_page_token="def", + ), + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[specialist_pool.SpecialistPool(),], + next_page_token="ghi", + ), + specialist_pool_service.ListSpecialistPoolsResponse( + specialist_pools=[ + specialist_pool.SpecialistPool(), + specialist_pool.SpecialistPool(), + ], + ), + RuntimeError, + ) + pages = [] + async for page_ in (await client.list_specialist_pools(request={})).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_delete_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.DeleteSpecialistPoolRequest, +): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, future.Future) + + +def test_delete_specialist_pool_from_dict(): + test_delete_specialist_pool(request_type=dict) + + +@pytest.mark.asyncio +async def test_delete_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.DeleteSpecialistPoolRequest, +): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.DeleteSpecialistPoolRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_delete_specialist_pool_async_from_dict(): + await test_delete_specialist_pool_async(request_type=dict) + + +def test_delete_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.DeleteSpecialistPoolRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_specialist_pool_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.DeleteSpecialistPoolRequest() + request.name = "name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.delete_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. 
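+ # Routing parameters travel in the call metadata as an + # ("x-goog-request-params", "name=<value>") tuple, which is what the + # membership check below verifies.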
+ _, _, kw = call.mock_calls[0] + assert ("x-goog-request-params", "name=name/value",) in kw["metadata"] + + +def test_delete_specialist_pool_flattened(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_specialist_pool(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +def test_delete_specialist_pool_flattened_error(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_specialist_pool( + specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_specialist_pool_flattened_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_specialist_pool(name="name_value",) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].name == "name_value" + + +@pytest.mark.asyncio +async def test_delete_specialist_pool_flattened_error_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_specialist_pool( + specialist_pool_service.DeleteSpecialistPoolRequest(), name="name_value", + ) + + +def test_update_specialist_pool( + transport: str = "grpc", + request_type=specialist_pool_service.UpdateSpecialistPoolRequest, +): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. 
+ call.return_value = operations_pb2.Operation(name="operations/spam") + + response = client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_update_specialist_pool_from_dict(): + test_update_specialist_pool(request_type=dict) + + +@pytest.mark.asyncio +async def test_update_specialist_pool_async( + transport: str = "grpc_asyncio", + request_type=specialist_pool_service.UpdateSpecialistPoolRequest, +): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + + response = await client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0] == specialist_pool_service.UpdateSpecialistPoolRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_update_specialist_pool_async_from_dict(): + await test_update_specialist_pool_async(request_type=dict) + + +def test_update_specialist_pool_field_headers(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.UpdateSpecialistPoolRequest() + request.specialist_pool.name = "specialist_pool.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_specialist_pool), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + + client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "specialist_pool.name=specialist_pool.name/value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_specialist_pool_field_headers_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = specialist_pool_service.UpdateSpecialistPoolRequest() + request.specialist_pool.name = "specialist_pool.name/value" + + # Mock the actual call within the gRPC stub, and fake the request. 
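+ # For Update, the routing value is taken from the nested + # specialist_pool.name field, so the expected metadata entry is keyed + # "specialist_pool.name=..." rather than by a top-level field.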
+ with mock.patch.object( + type(client.transport.update_specialist_pool), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + + await client.update_specialist_pool(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "specialist_pool.name=specialist_pool.name/value", + ) in kw["metadata"] + + +def test_update_specialist_pool_flattened(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_specialist_pool( + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +def test_update_specialist_pool_flattened_error(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_specialist_pool( + specialist_pool_service.UpdateSpecialistPoolRequest(), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_specialist_pool_flattened_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_specialist_pool), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_specialist_pool( + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + + assert args[0].specialist_pool == gca_specialist_pool.SpecialistPool( + name="name_value" + ) + + assert args[0].update_mask == field_mask.FieldMask(paths=["paths_value"]) + + +@pytest.mark.asyncio +async def test_update_specialist_pool_flattened_error_async(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_specialist_pool( + specialist_pool_service.UpdateSpecialistPoolRequest(), + specialist_pool=gca_specialist_pool.SpecialistPool(name="name_value"), + update_mask=field_mask.FieldMask(paths=["paths_value"]), + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = SpecialistPoolServiceClient( + client_options={"scopes": ["1", "2"]}, transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + client = SpecialistPoolServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.SpecialistPoolServiceGrpcTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport( + credentials=credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. 
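+ # Constructing the client without a transport argument is equivalent to + # passing transport="grpc" explicitly, as the LRO tests further down do.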
+ client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + ) + assert isinstance(client.transport, transports.SpecialistPoolServiceGrpcTransport,) + + +def test_specialist_pool_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(exceptions.DuplicateCredentialArgs): + transport = transports.SpecialistPoolServiceTransport( + credentials=credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_specialist_pool_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.SpecialistPoolServiceTransport( + credentials=credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_specialist_pool", + "get_specialist_pool", + "list_specialist_pools", + "delete_specialist_pool", + "update_specialist_pool", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + +def test_specialist_pool_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + auth, "load_credentials_from_file" + ) as load_creds, mock.patch( + "google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.SpecialistPoolServiceTransport( + credentials_file="credentials.json", quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +def test_specialist_pool_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(auth, "default") as adc, mock.patch( + "google.cloud.aiplatform_v1.services.specialist_pool_service.transports.SpecialistPoolServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (credentials.AnonymousCredentials(), None) + transport = transports.SpecialistPoolServiceTransport() + adc.assert_called_once() + + +def test_specialist_pool_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + SpecialistPoolServiceClient() + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id=None, + ) + + +def test_specialist_pool_service_transport_auth_adc(): + # If credentials and host are not provided, the transport class should use + # ADC credentials. 
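+ # ADC resolution goes through google.auth.default(), which is mocked below + # so that no real credential lookup happens during the test.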
+ with mock.patch.object(auth, "default") as adc: + adc.return_value = (credentials.AnonymousCredentials(), None) + transports.SpecialistPoolServiceGrpcTransport( + host="squid.clam.whelk", quota_project_id="octopus" + ) + adc.assert_called_once_with( + scopes=("https://www.googleapis.com/auth/cloud-platform",), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) +def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls( + transport_class, +): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_specialist_pool_service_host_no_port(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:443" + + +def test_specialist_pool_service_host_with_port(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="aiplatform.googleapis.com:8000" + ), + ) + assert client.transport._host == "aiplatform.googleapis.com:8000" + + +def test_specialist_pool_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.SpecialistPoolServiceGrpcTransport( + host="squid.clam.whelk", channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_specialist_pool_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. 
+@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) +def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source( + transport_class, +): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.SpecialistPoolServiceGrpcTransport, + transports.SpecialistPoolServiceGrpcAsyncIOTransport, + ], +) +def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_specialist_pool_service_grpc_lro_client(): + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), transport="grpc", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsClient,) + + # Ensure that subsequent calls to the property send the exact same object. 
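+ # The transport caches the operations client on first access, hence the + # identity (is) check below rather than a mere equality check.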
+ assert transport.operations_client is transport.operations_client + + +def test_specialist_pool_service_grpc_lro_async_client(): + client = SpecialistPoolServiceAsyncClient( + credentials=credentials.AnonymousCredentials(), transport="grpc_asyncio", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance(transport.operations_client, operations_v1.OperationsAsyncClient,) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_specialist_pool_path(): + project = "squid" + location = "clam" + specialist_pool = "whelk" + + expected = "projects/{project}/locations/{location}/specialistPools/{specialist_pool}".format( + project=project, location=location, specialist_pool=specialist_pool, + ) + actual = SpecialistPoolServiceClient.specialist_pool_path( + project, location, specialist_pool + ) + assert expected == actual + + +def test_parse_specialist_pool_path(): + expected = { + "project": "octopus", + "location": "oyster", + "specialist_pool": "nudibranch", + } + path = SpecialistPoolServiceClient.specialist_pool_path(**expected) + + # Check that the path construction is reversible. + actual = SpecialistPoolServiceClient.parse_specialist_pool_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "cuttlefish" + + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = SpecialistPoolServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "mussel", + } + path = SpecialistPoolServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = SpecialistPoolServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "winkle" + + expected = "folders/{folder}".format(folder=folder,) + actual = SpecialistPoolServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nautilus", + } + path = SpecialistPoolServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = SpecialistPoolServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "scallop" + + expected = "organizations/{organization}".format(organization=organization,) + actual = SpecialistPoolServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "abalone", + } + path = SpecialistPoolServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = SpecialistPoolServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "squid" + + expected = "projects/{project}".format(project=project,) + actual = SpecialistPoolServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "clam", + } + path = SpecialistPoolServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. 
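+ # Each *_path helper formats a resource name and has a matching parse_* + # helper that inverts it, e.g. "projects/clam" round-trips back to + # {"project": "clam"}.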
+ actual = SpecialistPoolServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "whelk" + location = "octopus" + + expected = "projects/{project}/locations/{location}".format( + project=project, location=location, + ) + actual = SpecialistPoolServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + } + path = SpecialistPoolServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = SpecialistPoolServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_withDEFAULT_CLIENT_INFO(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = SpecialistPoolServiceClient( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.SpecialistPoolServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = SpecialistPoolServiceClient.get_transport_class() + transport = transport_class( + credentials=credentials.AnonymousCredentials(), client_info=client_info, + ) + prep.assert_called_once_with(client_info) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py index 51022d9fb7..fe6e04c2ec 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_dataset_service.py @@ -49,6 +49,7 @@ from google.cloud.aiplatform_v1beta1.types import dataset from google.cloud.aiplatform_v1beta1.types import dataset as gca_dataset from google.cloud.aiplatform_v1beta1.types import dataset_service +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import io from google.cloud.aiplatform_v1beta1.types import operation as gca_operation from google.longrunning import operations_pb2 @@ -102,8 +103,21 @@ def test__get_default_mtls_endpoint(): ) +def test_dataset_service_client_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = DatasetServiceClient.from_service_account_info(info) + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + @pytest.mark.parametrize( - "client_class", [DatasetServiceClient, DatasetServiceAsyncClient] + "client_class", [DatasetServiceClient, DatasetServiceAsyncClient,] ) def test_dataset_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -122,7 +136,10 @@ def test_dataset_service_client_from_service_account_file(client_class): def test_dataset_service_client_get_transport_class(): transport = DatasetServiceClient.get_transport_class() - assert transport == transports.DatasetServiceGrpcTransport + available_transports = [ + transports.DatasetServiceGrpcTransport, + ] + assert transport in available_transports transport = DatasetServiceClient.get_transport_class("grpc") assert transport == transports.DatasetServiceGrpcTransport @@ -173,7 +190,7 @@ def 
test_dataset_service_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +206,7 @@ def test_dataset_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -205,7 +222,7 @@ def test_dataset_service_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -233,7 +250,7 @@ def test_dataset_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -284,29 +301,25 @@ def test_dataset_service_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. 
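The rewritten expectations above are the heart of this patch: instead of stubbing google.auth.transport.grpc.SslCredentials and asserting on pre-built ssl_channel_credentials, the tests now drive the google.auth.transport.mtls helpers and assert that the client hands the cert-source callback itself to the transport as client_cert_source_for_mtls. The following is a minimal sketch of the resolution flow these tests exercise, assuming invented endpoint constants and a simplified reading of GOOGLE_API_USE_CLIENT_CERTIFICATE; it is not the generated client code itself, which also validates the variable's value.

import os
from unittest import mock

from google.auth.transport import mtls

# Stand-ins for the generated client's DEFAULT_ENDPOINT / DEFAULT_MTLS_ENDPOINT.
DEFAULT_ENDPOINT = "aiplatform.googleapis.com"
DEFAULT_MTLS_ENDPOINT = "aiplatform.mtls.googleapis.com"


def resolve_mtls(client_cert_source=None):
    """Return (client_cert_source_for_mtls, host) the way the tests expect."""
    if os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") != "true":
        return None, DEFAULT_ENDPOINT
    if client_cert_source is not None:
        # An explicit callback from client_options wins.
        return client_cert_source, DEFAULT_MTLS_ENDPOINT
    if mtls.has_default_client_cert_source():
        # Otherwise fall back to the ADC-provided client certificate, if any.
        return mtls.default_client_cert_source(), DEFAULT_MTLS_ENDPOINT
    return None, DEFAULT_ENDPOINT


def test_resolve_mtls_with_adc_cert():
    def callback():
        return b"cert", b"key"

    with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}):
        with mock.patch.object(mtls, "has_default_client_cert_source", return_value=True):
            with mock.patch.object(mtls, "default_client_cert_source", return_value=callback):
                source, host = resolve_mtls()
                assert source is callback
                assert host == DEFAULT_MTLS_ENDPOINT

Deferring the callback to the transport, rather than building SSL credentials in the client, is what lets the tests below drop the grpc.ssl_channel_credentials mock entirely.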
@@ -315,66 +328,53 @@ def test_dataset_service_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. 
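+    # (has_default_client_cert_source is patched to return False below, so the
+    # client must fall back to DEFAULT_ENDPOINT and pass
+    # client_cert_source_for_mtls=None, whatever the env var says.)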
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -400,7 +400,7 @@ def test_dataset_service_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -430,7 +430,7 @@ def test_dataset_service_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -449,7 +449,7 @@ def test_dataset_service_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -3099,6 +3099,51 @@ def test_dataset_service_transport_auth_adc(): ) +@pytest.mark.parametrize( + "transport_class", + [ + transports.DatasetServiceGrpcTransport, + transports.DatasetServiceGrpcAsyncIOTransport, + ], +) +def test_dataset_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + def test_dataset_service_host_no_port(): client = DatasetServiceClient( credentials=credentials.AnonymousCredentials(), @@ -3120,7 +3165,7 @@ def test_dataset_service_host_with_port(): def test_dataset_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. 
transport = transports.DatasetServiceGrpcTransport( @@ -3132,7 +3177,7 @@ def test_dataset_service_grpc_transport_channel(): def test_dataset_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.DatasetServiceGrpcAsyncIOTransport( @@ -3143,6 +3188,8 @@ def test_dataset_service_grpc_asyncio_transport_channel(): assert transport._ssl_channel_credentials == None +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [ @@ -3157,7 +3204,7 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( "grpc.ssl_channel_credentials", autospec=True ) as grpc_ssl_channel_cred: with mock.patch.object( - transport_class, "create_channel", autospec=True + transport_class, "create_channel" ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -3186,11 +3233,17 @@ def test_dataset_service_transport_channel_mtls_with_client_cert_source( scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel assert transport._ssl_channel_credentials == mock_ssl_cred +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [ @@ -3206,7 +3259,7 @@ def test_dataset_service_transport_channel_mtls_with_adc(transport_class): ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): with mock.patch.object( - transport_class, "create_channel", autospec=True + transport_class, "create_channel" ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel @@ -3227,6 +3280,10 @@ def test_dataset_service_transport_channel_mtls_with_adc(transport_class): scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py index 93c35a7a2a..237d6d9268 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_endpoint_service.py @@ -44,6 +44,7 @@ from google.cloud.aiplatform_v1beta1.services.endpoint_service import pagers from google.cloud.aiplatform_v1beta1.services.endpoint_service import transports from google.cloud.aiplatform_v1beta1.types import accelerator_type +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import endpoint from google.cloud.aiplatform_v1beta1.types import endpoint as gca_endpoint from google.cloud.aiplatform_v1beta1.types import endpoint_service @@ -102,8 +103,21 @@ def test__get_default_mtls_endpoint(): ) +def test_endpoint_service_client_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, 
"from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = EndpointServiceClient.from_service_account_info(info) + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + @pytest.mark.parametrize( - "client_class", [EndpointServiceClient, EndpointServiceAsyncClient] + "client_class", [EndpointServiceClient, EndpointServiceAsyncClient,] ) def test_endpoint_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -122,7 +136,10 @@ def test_endpoint_service_client_from_service_account_file(client_class): def test_endpoint_service_client_get_transport_class(): transport = EndpointServiceClient.get_transport_class() - assert transport == transports.EndpointServiceGrpcTransport + available_transports = [ + transports.EndpointServiceGrpcTransport, + ] + assert transport in available_transports transport = EndpointServiceClient.get_transport_class("grpc") assert transport == transports.EndpointServiceGrpcTransport @@ -173,7 +190,7 @@ def test_endpoint_service_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +206,7 @@ def test_endpoint_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -205,7 +222,7 @@ def test_endpoint_service_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -233,7 +250,7 @@ def test_endpoint_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -294,29 +311,25 @@ def test_endpoint_service_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + 
quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -325,66 +338,53 @@ def test_endpoint_service_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -410,7 +410,7 @@ def test_endpoint_service_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -440,7 +440,7 @@ def test_endpoint_service_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -459,7 +459,7 @@ def test_endpoint_service_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -2252,6 +2252,51 @@ def test_endpoint_service_transport_auth_adc(): ) +@pytest.mark.parametrize( + "transport_class", + [ + transports.EndpointServiceGrpcTransport, + transports.EndpointServiceGrpcAsyncIOTransport, + ], +) +def test_endpoint_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + def test_endpoint_service_host_no_port(): client = EndpointServiceClient( credentials=credentials.AnonymousCredentials(), @@ -2273,7 +2318,7 @@ def test_endpoint_service_host_with_port(): def test_endpoint_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. 
transport = transports.EndpointServiceGrpcTransport( @@ -2285,7 +2330,7 @@ def test_endpoint_service_grpc_transport_channel(): def test_endpoint_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.EndpointServiceGrpcAsyncIOTransport( @@ -2296,6 +2341,8 @@ def test_endpoint_service_grpc_asyncio_transport_channel(): assert transport._ssl_channel_credentials == None +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [ @@ -2310,7 +2357,7 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( "grpc.ssl_channel_credentials", autospec=True ) as grpc_ssl_channel_cred: with mock.patch.object( - transport_class, "create_channel", autospec=True + transport_class, "create_channel" ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -2339,11 +2386,17 @@ def test_endpoint_service_transport_channel_mtls_with_client_cert_source( scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel assert transport._ssl_channel_credentials == mock_ssl_cred +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [ @@ -2359,7 +2412,7 @@ def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): with mock.patch.object( - transport_class, "create_channel", autospec=True + transport_class, "create_channel" ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel @@ -2380,6 +2433,10 @@ def test_endpoint_service_transport_channel_mtls_with_adc(transport_class): scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py index f08d84bd2f..67b1c6830f 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_job_service.py @@ -51,6 +51,7 @@ from google.cloud.aiplatform_v1beta1.types import ( data_labeling_job as gca_data_labeling_job, ) +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import explanation_metadata from google.cloud.aiplatform_v1beta1.types import hyperparameter_tuning_job @@ -116,7 +117,20 @@ def test__get_default_mtls_endpoint(): assert JobServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient]) +def test_job_service_client_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + 
service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = JobServiceClient.from_service_account_info(info) + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +@pytest.mark.parametrize("client_class", [JobServiceClient, JobServiceAsyncClient,]) def test_job_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( @@ -134,7 +148,10 @@ def test_job_service_client_from_service_account_file(client_class): def test_job_service_client_get_transport_class(): transport = JobServiceClient.get_transport_class() - assert transport == transports.JobServiceGrpcTransport + available_transports = [ + transports.JobServiceGrpcTransport, + ] + assert transport in available_transports transport = JobServiceClient.get_transport_class("grpc") assert transport == transports.JobServiceGrpcTransport @@ -183,7 +200,7 @@ def test_job_service_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -199,7 +216,7 @@ def test_job_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -215,7 +232,7 @@ def test_job_service_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -243,7 +260,7 @@ def test_job_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -292,29 +309,25 @@ def test_job_service_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is 
provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -323,66 +336,53 @@ def test_job_service_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. 
+ with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -408,7 +408,7 @@ def test_job_service_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -438,7 +438,7 @@ def test_job_service_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -455,7 +455,7 @@ def test_job_service_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -5541,7 +5541,7 @@ def test_transport_get_channel(): @pytest.mark.parametrize( "transport_class", - [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport,], ) def test_transport_adc(transport_class): # Test default credentials are used if not provided. @@ -5665,6 +5665,48 @@ def test_job_service_transport_auth_adc(): ) +@pytest.mark.parametrize( + "transport_class", + [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], +) +def test_job_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. 
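+    # (The transport should invoke the callback once and hand the resulting
+    # (cert, key) pair to grpc.ssl_channel_credentials as certificate_chain
+    # and private_key; the mock below asserts exactly that call.)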
+ with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + def test_job_service_host_no_port(): client = JobServiceClient( credentials=credentials.AnonymousCredentials(), @@ -5686,7 +5728,7 @@ def test_job_service_host_with_port(): def test_job_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.JobServiceGrpcTransport( @@ -5698,7 +5740,7 @@ def test_job_service_grpc_transport_channel(): def test_job_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.JobServiceGrpcAsyncIOTransport( @@ -5709,6 +5751,8 @@ def test_job_service_grpc_asyncio_transport_channel(): assert transport._ssl_channel_credentials == None +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], @@ -5718,7 +5762,7 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(transport_cl "grpc.ssl_channel_credentials", autospec=True ) as grpc_ssl_channel_cred: with mock.patch.object( - transport_class, "create_channel", autospec=True + transport_class, "create_channel" ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -5747,11 +5791,17 @@ def test_job_service_transport_channel_mtls_with_client_cert_source(transport_cl scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel assert transport._ssl_channel_credentials == mock_ssl_cred +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. 
@pytest.mark.parametrize( "transport_class", [transports.JobServiceGrpcTransport, transports.JobServiceGrpcAsyncIOTransport], @@ -5764,7 +5814,7 @@ def test_job_service_transport_channel_mtls_with_adc(transport_class): ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): with mock.patch.object( - transport_class, "create_channel", autospec=True + transport_class, "create_channel" ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel @@ -5785,6 +5835,10 @@ def test_job_service_transport_channel_mtls_with_adc(transport_class): scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel @@ -5973,8 +6027,35 @@ def test_parse_model_path(): assert expected == actual +def test_trial_path(): + project = "squid" + location = "clam" + study = "whelk" + trial = "octopus" + + expected = "projects/{project}/locations/{location}/studies/{study}/trials/{trial}".format( + project=project, location=location, study=study, trial=trial, + ) + actual = JobServiceClient.trial_path(project, location, study, trial) + assert expected == actual + + +def test_parse_trial_path(): + expected = { + "project": "oyster", + "location": "nudibranch", + "study": "cuttlefish", + "trial": "mussel", + } + path = JobServiceClient.trial_path(**expected) + + # Check that the path construction is reversible. + actual = JobServiceClient.parse_trial_path(path) + assert expected == actual + + def test_common_billing_account_path(): - billing_account = "squid" + billing_account = "winkle" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, @@ -5985,7 +6066,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "clam", + "billing_account": "nautilus", } path = JobServiceClient.common_billing_account_path(**expected) @@ -5995,7 +6076,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "whelk" + folder = "scallop" expected = "folders/{folder}".format(folder=folder,) actual = JobServiceClient.common_folder_path(folder) @@ -6004,7 +6085,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "octopus", + "folder": "abalone", } path = JobServiceClient.common_folder_path(**expected) @@ -6014,7 +6095,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "oyster" + organization = "squid" expected = "organizations/{organization}".format(organization=organization,) actual = JobServiceClient.common_organization_path(organization) @@ -6023,7 +6104,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "nudibranch", + "organization": "clam", } path = JobServiceClient.common_organization_path(**expected) @@ -6033,7 +6114,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "cuttlefish" + project = "whelk" expected = "projects/{project}".format(project=project,) actual = JobServiceClient.common_project_path(project) @@ -6042,7 +6123,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "mussel", + "project": "octopus", } path = JobServiceClient.common_project_path(**expected) @@ -6052,8 +6133,8 @@ def 
test_parse_common_project_path(): def test_common_location_path(): - project = "winkle" - location = "nautilus" + project = "oyster" + location = "nudibranch" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -6064,8 +6145,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "scallop", - "location": "abalone", + "project": "cuttlefish", + "location": "mussel", } path = JobServiceClient.common_location_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py index 85e6a2d362..8594354c88 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_migration_service.py @@ -94,8 +94,21 @@ def test__get_default_mtls_endpoint(): ) +def test_migration_service_client_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = MigrationServiceClient.from_service_account_info(info) + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + @pytest.mark.parametrize( - "client_class", [MigrationServiceClient, MigrationServiceAsyncClient] + "client_class", [MigrationServiceClient, MigrationServiceAsyncClient,] ) def test_migration_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() @@ -114,7 +127,10 @@ def test_migration_service_client_from_service_account_file(client_class): def test_migration_service_client_get_transport_class(): transport = MigrationServiceClient.get_transport_class() - assert transport == transports.MigrationServiceGrpcTransport + available_transports = [ + transports.MigrationServiceGrpcTransport, + ] + assert transport in available_transports transport = MigrationServiceClient.get_transport_class("grpc") assert transport == transports.MigrationServiceGrpcTransport @@ -165,7 +181,7 @@ def test_migration_service_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -181,7 +197,7 @@ def test_migration_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -197,7 +213,7 @@ def test_migration_service_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -225,7 +241,7 @@ def test_migration_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -286,29 +302,25 @@ def test_migration_service_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", 
return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -317,66 +329,53 @@ def test_migration_service_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. 
- with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -402,7 +401,7 @@ def test_migration_service_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -432,7 +431,7 @@ def test_migration_service_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -451,7 +450,7 @@ def test_migration_service_client_client_options_from_dict(): credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -1266,6 +1265,51 @@ def test_migration_service_transport_auth_adc(): ) +@pytest.mark.parametrize( + "transport_class", + [ + transports.MigrationServiceGrpcTransport, + transports.MigrationServiceGrpcAsyncIOTransport, + ], +) +def test_migration_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=("https://www.googleapis.com/auth/cloud-platform",), + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. 
+ with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + def test_migration_service_host_no_port(): client = MigrationServiceClient( credentials=credentials.AnonymousCredentials(), @@ -1287,7 +1331,7 @@ def test_migration_service_host_with_port(): def test_migration_service_grpc_transport_channel(): - channel = grpc.insecure_channel("http://localhost/") + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcTransport( @@ -1299,7 +1343,7 @@ def test_migration_service_grpc_transport_channel(): def test_migration_service_grpc_asyncio_transport_channel(): - channel = aio.insecure_channel("http://localhost/") + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) # Check that channel is used if provided. transport = transports.MigrationServiceGrpcAsyncIOTransport( @@ -1310,6 +1354,8 @@ def test_migration_service_grpc_asyncio_transport_channel(): assert transport._ssl_channel_credentials == None +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. @pytest.mark.parametrize( "transport_class", [ @@ -1324,7 +1370,7 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( "grpc.ssl_channel_credentials", autospec=True ) as grpc_ssl_channel_cred: with mock.patch.object( - transport_class, "create_channel", autospec=True + transport_class, "create_channel" ) as grpc_create_channel: mock_ssl_cred = mock.Mock() grpc_ssl_channel_cred.return_value = mock_ssl_cred @@ -1353,11 +1399,17 @@ def test_migration_service_transport_channel_mtls_with_client_cert_source( scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel assert transport._ssl_channel_credentials == mock_ssl_cred +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. 
@pytest.mark.parametrize( "transport_class", [ @@ -1373,7 +1425,7 @@ def test_migration_service_transport_channel_mtls_with_adc(transport_class): ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), ): with mock.patch.object( - transport_class, "create_channel", autospec=True + transport_class, "create_channel" ) as grpc_create_channel: mock_grpc_channel = mock.Mock() grpc_create_channel.return_value = mock_grpc_channel @@ -1394,6 +1446,10 @@ def test_migration_service_transport_channel_mtls_with_adc(transport_class): scopes=("https://www.googleapis.com/auth/cloud-platform",), ssl_credentials=mock_ssl_cred, quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], ) assert transport.grpc_channel == mock_grpc_channel diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py index d05698a46a..05bb815f3f 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py @@ -42,6 +42,7 @@ from google.cloud.aiplatform_v1beta1.services.model_service import pagers from google.cloud.aiplatform_v1beta1.services.model_service import transports from google.cloud.aiplatform_v1beta1.types import deployed_model_ref +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import env_var from google.cloud.aiplatform_v1beta1.types import explanation from google.cloud.aiplatform_v1beta1.types import explanation_metadata @@ -100,7 +101,20 @@ def test__get_default_mtls_endpoint(): assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi -@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient]) +def test_model_service_client_from_service_account_info(): + creds = credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = ModelServiceClient.from_service_account_info(info) + assert client.transport._credentials == creds + + assert client.transport._host == "aiplatform.googleapis.com:443" + + +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient,]) def test_model_service_client_from_service_account_file(client_class): creds = credentials.AnonymousCredentials() with mock.patch.object( @@ -118,7 +132,10 @@ def test_model_service_client_from_service_account_file(client_class): def test_model_service_client_get_transport_class(): transport = ModelServiceClient.get_transport_class() - assert transport == transports.ModelServiceGrpcTransport + available_transports = [ + transports.ModelServiceGrpcTransport, + ] + assert transport in available_transports transport = ModelServiceClient.get_transport_class("grpc") assert transport == transports.ModelServiceGrpcTransport @@ -167,7 +184,7 @@ def test_model_service_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -183,7 +200,7 @@ def test_model_service_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -199,7 +216,7 @@ def 
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py
index d05698a46a..05bb815f3f 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_model_service.py
@@ -42,6 +42,7 @@
 from google.cloud.aiplatform_v1beta1.services.model_service import pagers
 from google.cloud.aiplatform_v1beta1.services.model_service import transports
 from google.cloud.aiplatform_v1beta1.types import deployed_model_ref
+from google.cloud.aiplatform_v1beta1.types import encryption_spec
 from google.cloud.aiplatform_v1beta1.types import env_var
 from google.cloud.aiplatform_v1beta1.types import explanation
 from google.cloud.aiplatform_v1beta1.types import explanation_metadata
@@ -100,7 +101,20 @@ def test__get_default_mtls_endpoint():
     assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi


-@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient])
+def test_model_service_client_from_service_account_info():
+    creds = credentials.AnonymousCredentials()
+    with mock.patch.object(
+        service_account.Credentials, "from_service_account_info"
+    ) as factory:
+        factory.return_value = creds
+        info = {"valid": True}
+        client = ModelServiceClient.from_service_account_info(info)
+        assert client.transport._credentials == creds
+
+        assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
+@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient,])
 def test_model_service_client_from_service_account_file(client_class):
     creds = credentials.AnonymousCredentials()
     with mock.patch.object(
@@ -118,7 +132,10 @@ def test_model_service_client_from_service_account_file(client_class):

 def test_model_service_client_get_transport_class():
     transport = ModelServiceClient.get_transport_class()
-    assert transport == transports.ModelServiceGrpcTransport
+    available_transports = [
+        transports.ModelServiceGrpcTransport,
+    ]
+    assert transport in available_transports

     transport = ModelServiceClient.get_transport_class("grpc")
     assert transport == transports.ModelServiceGrpcTransport
@@ -167,7 +184,7 @@ def test_model_service_client_client_options(
             credentials_file=None,
             host="squid.clam.whelk",
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id=None,
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -183,7 +200,7 @@ def test_model_service_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_ENDPOINT,
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id=None,
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -199,7 +216,7 @@ def test_model_service_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_MTLS_ENDPOINT,
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id=None,
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -227,7 +244,7 @@ def test_model_service_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_ENDPOINT,
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id="octopus",
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -276,29 +293,25 @@ def test_model_service_client_mtls_env_auto(
                 client_cert_source=client_cert_source_callback
             )
             with mock.patch.object(transport_class, "__init__") as patched:
-                ssl_channel_creds = mock.Mock()
-                with mock.patch(
-                    "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
-                ):
-                    patched.return_value = None
-                    client = client_class(client_options=options)
+                patched.return_value = None
+                client = client_class(client_options=options)

-                    if use_client_cert_env == "false":
-                        expected_ssl_channel_creds = None
-                        expected_host = client.DEFAULT_ENDPOINT
-                    else:
-                        expected_ssl_channel_creds = ssl_channel_creds
-                        expected_host = client.DEFAULT_MTLS_ENDPOINT
+                if use_client_cert_env == "false":
+                    expected_client_cert_source = None
+                    expected_host = client.DEFAULT_ENDPOINT
+                else:
+                    expected_client_cert_source = client_cert_source_callback
+                    expected_host = client.DEFAULT_MTLS_ENDPOINT

-                    patched.assert_called_once_with(
-                        credentials=None,
-                        credentials_file=None,
-                        host=expected_host,
-                        scopes=None,
-                        ssl_channel_credentials=expected_ssl_channel_creds,
-                        quota_project_id=None,
-                        client_info=transports.base.DEFAULT_CLIENT_INFO,
-                    )
+                patched.assert_called_once_with(
+                    credentials=None,
+                    credentials_file=None,
+                    host=expected_host,
+                    scopes=None,
+                    client_cert_source_for_mtls=expected_client_cert_source,
+                    quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
+                )

     # Check the case ADC client cert is provided. Whether client cert is used depends on
     # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
@@ -307,66 +320,53 @@ def test_model_service_client_mtls_env_auto(
     ):
         with mock.patch.object(transport_class, "__init__") as patched:
             with mock.patch(
-                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+                "google.auth.transport.mtls.has_default_client_cert_source",
+                return_value=True,
             ):
                 with mock.patch(
-                    "google.auth.transport.grpc.SslCredentials.is_mtls",
-                    new_callable=mock.PropertyMock,
-                ) as is_mtls_mock:
-                    with mock.patch(
-                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
-                        new_callable=mock.PropertyMock,
-                    ) as ssl_credentials_mock:
-                        if use_client_cert_env == "false":
-                            is_mtls_mock.return_value = False
-                            ssl_credentials_mock.return_value = None
-                            expected_host = client.DEFAULT_ENDPOINT
-                            expected_ssl_channel_creds = None
-                        else:
-                            is_mtls_mock.return_value = True
-                            ssl_credentials_mock.return_value = mock.Mock()
-                            expected_host = client.DEFAULT_MTLS_ENDPOINT
-                            expected_ssl_channel_creds = (
-                                ssl_credentials_mock.return_value
-                            )
-
-                        patched.return_value = None
-                        client = client_class()
-                        patched.assert_called_once_with(
-                            credentials=None,
-                            credentials_file=None,
-                            host=expected_host,
-                            scopes=None,
-                            ssl_channel_credentials=expected_ssl_channel_creds,
-                            quota_project_id=None,
-                            client_info=transports.base.DEFAULT_CLIENT_INFO,
-                        )
+                    "google.auth.transport.mtls.default_client_cert_source",
+                    return_value=client_cert_source_callback,
+                ):
+                    if use_client_cert_env == "false":
+                        expected_host = client.DEFAULT_ENDPOINT
+                        expected_client_cert_source = None
+                    else:
+                        expected_host = client.DEFAULT_MTLS_ENDPOINT
+                        expected_client_cert_source = client_cert_source_callback

-    # Check the case client_cert_source and ADC client cert are not provided.
-    with mock.patch.dict(
-        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
-    ):
-        with mock.patch.object(transport_class, "__init__") as patched:
-            with mock.patch(
-                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
-            ):
-                with mock.patch(
-                    "google.auth.transport.grpc.SslCredentials.is_mtls",
-                    new_callable=mock.PropertyMock,
-                ) as is_mtls_mock:
-                    is_mtls_mock.return_value = False
                     patched.return_value = None
                     client = client_class()
                     patched.assert_called_once_with(
                         credentials=None,
                         credentials_file=None,
-                        host=client.DEFAULT_ENDPOINT,
+                        host=expected_host,
                         scopes=None,
-                        ssl_channel_credentials=None,
+                        client_cert_source_for_mtls=expected_client_cert_source,
                         quota_project_id=None,
                         client_info=transports.base.DEFAULT_CLIENT_INFO,
                     )

+    # Check the case client_cert_source and ADC client cert are not provided.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.mtls.has_default_client_cert_source",
+                return_value=False,
+            ):
+                patched.return_value = None
+                client = client_class()
+                patched.assert_called_once_with(
+                    credentials=None,
+                    credentials_file=None,
+                    host=client.DEFAULT_ENDPOINT,
+                    scopes=None,
+                    client_cert_source_for_mtls=None,
+                    quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
+                )
+

 @pytest.mark.parametrize(
     "client_class,transport_class,transport_name",
@@ -392,7 +392,7 @@ def test_model_service_client_client_options_scopes(
         credentials_file=None,
         host=client.DEFAULT_ENDPOINT,
         scopes=["1", "2"],
-        ssl_channel_credentials=None,
+        client_cert_source_for_mtls=None,
         quota_project_id=None,
         client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -422,7 +422,7 @@ def test_model_service_client_client_options_credentials_file(
         credentials_file="credentials.json",
         host=client.DEFAULT_ENDPOINT,
         scopes=None,
-        ssl_channel_credentials=None,
+        client_cert_source_for_mtls=None,
         quota_project_id=None,
         client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -439,7 +439,7 @@ def test_model_service_client_client_options_from_dict():
         credentials_file=None,
         host="squid.clam.whelk",
         scopes=None,
-        ssl_channel_credentials=None,
+        client_cert_source_for_mtls=None,
         quota_project_id=None,
         client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -3149,7 +3149,10 @@ def test_transport_get_channel():

 @pytest.mark.parametrize(
     "transport_class",
-    [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport],
+    [
+        transports.ModelServiceGrpcTransport,
+        transports.ModelServiceGrpcAsyncIOTransport,
+    ],
 )
 def test_transport_adc(transport_class):
     # Test default credentials are used if not provided.
@@ -3263,6 +3266,48 @@ def test_model_service_transport_auth_adc():
     )


+@pytest.mark.parametrize(
+    "transport_class",
+    [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport],
+)
+def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_class):
+    cred = credentials.AnonymousCredentials()
+
+    # Check ssl_channel_credentials is used if provided.
+    with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+        mock_ssl_channel_creds = mock.Mock()
+        transport_class(
+            host="squid.clam.whelk",
+            credentials=cred,
+            ssl_channel_credentials=mock_ssl_channel_creds,
+        )
+        mock_create_channel.assert_called_once_with(
+            "squid.clam.whelk:443",
+            credentials=cred,
+            credentials_file=None,
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            ssl_credentials=mock_ssl_channel_creds,
+            quota_project_id=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+    # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+    # is used.
+    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+            transport_class(
+                credentials=cred,
+                client_cert_source_for_mtls=client_cert_source_callback,
+            )
+            expected_cert, expected_key = client_cert_source_callback()
+            mock_ssl_cred.assert_called_once_with(
+                certificate_chain=expected_cert, private_key=expected_key
+            )
+
+
 def test_model_service_host_no_port():
     client = ModelServiceClient(
         credentials=credentials.AnonymousCredentials(),
@@ -3284,7 +3329,7 @@ def test_model_service_host_with_port():


 def test_model_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
+    channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())

     # Check that channel is used if provided.
     transport = transports.ModelServiceGrpcTransport(
@@ -3296,7 +3341,7 @@ def test_model_service_grpc_transport_channel():


 def test_model_service_grpc_asyncio_transport_channel():
-    channel = aio.insecure_channel("http://localhost/")
+    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

     # Check that channel is used if provided.
     transport = transports.ModelServiceGrpcAsyncIOTransport(
@@ -3307,6 +3352,8 @@ def test_model_service_grpc_asyncio_transport_channel():
     assert transport._ssl_channel_credentials == None


+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
 @pytest.mark.parametrize(
     "transport_class",
     [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport],
 )
@@ -3316,7 +3363,7 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(transport_
         "grpc.ssl_channel_credentials", autospec=True
     ) as grpc_ssl_channel_cred:
         with mock.patch.object(
-            transport_class, "create_channel", autospec=True
+            transport_class, "create_channel"
         ) as grpc_create_channel:
             mock_ssl_cred = mock.Mock()
             grpc_ssl_channel_cred.return_value = mock_ssl_cred
@@ -3345,11 +3392,17 @@ def test_model_service_transport_channel_mtls_with_client_cert_source(transport_
                 scopes=("https://www.googleapis.com/auth/cloud-platform",),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
             )
             assert transport.grpc_channel == mock_grpc_channel
             assert transport._ssl_channel_credentials == mock_ssl_cred


+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
 @pytest.mark.parametrize(
     "transport_class",
     [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport],
 )
@@ -3362,7 +3415,7 @@ def test_model_service_transport_channel_mtls_with_adc(transport_class):
         ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
     ):
         with mock.patch.object(
-            transport_class, "create_channel", autospec=True
+            transport_class, "create_channel"
         ) as grpc_create_channel:
             mock_grpc_channel = mock.Mock()
             grpc_create_channel.return_value = mock_grpc_channel
@@ -3383,6 +3436,10 @@ def test_model_service_transport_channel_mtls_with_adc(transport_class):
                 scopes=("https://www.googleapis.com/auth/cloud-platform",),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
             )
         assert transport.grpc_channel == mock_grpc_channel
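
The rewritten mTLS tests above pin down a new division of labor: the client
hands the transport a callback (client_cert_source_for_mtls) instead of
ready-made ssl_channel_credentials, and the transport invokes it lazily when
it builds the channel. A hedged sketch of the precedence the new
test_*_grpc_transport_client_cert_source_for_mtls tests assert (the helper
name below is invented for illustration, not part of the library):

    import grpc

    def build_channel_credentials(ssl_channel_credentials=None,
                                  client_cert_source_for_mtls=None):
        # Explicit channel credentials always win.
        if ssl_channel_credentials is not None:
            return ssl_channel_credentials
        # Otherwise call the cert source, which returns PEM bytes.
        if client_cert_source_for_mtls is not None:
            cert, key = client_cert_source_for_mtls()
            return grpc.ssl_channel_credentials(
                certificate_chain=cert, private_key=key
            )
        return grpc.ssl_channel_credentials()  # server-side TLS only
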
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
index ada82b91c0..8135921566 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_pipeline_service.py
@@ -44,6 +44,7 @@
 from google.cloud.aiplatform_v1beta1.services.pipeline_service import pagers
 from google.cloud.aiplatform_v1beta1.services.pipeline_service import transports
 from google.cloud.aiplatform_v1beta1.types import deployed_model_ref
+from google.cloud.aiplatform_v1beta1.types import encryption_spec
 from google.cloud.aiplatform_v1beta1.types import env_var
 from google.cloud.aiplatform_v1beta1.types import explanation
 from google.cloud.aiplatform_v1beta1.types import explanation_metadata
@@ -109,8 +110,21 @@ def test__get_default_mtls_endpoint():
     )


+def test_pipeline_service_client_from_service_account_info():
+    creds = credentials.AnonymousCredentials()
+    with mock.patch.object(
+        service_account.Credentials, "from_service_account_info"
+    ) as factory:
+        factory.return_value = creds
+        info = {"valid": True}
+        client = PipelineServiceClient.from_service_account_info(info)
+        assert client.transport._credentials == creds
+
+        assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
 @pytest.mark.parametrize(
-    "client_class", [PipelineServiceClient, PipelineServiceAsyncClient]
+    "client_class", [PipelineServiceClient, PipelineServiceAsyncClient,]
 )
 def test_pipeline_service_client_from_service_account_file(client_class):
     creds = credentials.AnonymousCredentials()
@@ -129,7 +143,10 @@ def test_pipeline_service_client_from_service_account_file(client_class):

 def test_pipeline_service_client_get_transport_class():
     transport = PipelineServiceClient.get_transport_class()
-    assert transport == transports.PipelineServiceGrpcTransport
+    available_transports = [
+        transports.PipelineServiceGrpcTransport,
+    ]
+    assert transport in available_transports

     transport = PipelineServiceClient.get_transport_class("grpc")
     assert transport == transports.PipelineServiceGrpcTransport
@@ -180,7 +197,7 @@ def test_pipeline_service_client_client_options(
             credentials_file=None,
             host="squid.clam.whelk",
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id=None,
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -196,7 +213,7 @@ def test_pipeline_service_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_ENDPOINT,
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id=None,
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -212,7 +229,7 @@ def test_pipeline_service_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_MTLS_ENDPOINT,
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id=None,
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -240,7 +257,7 @@ def test_pipeline_service_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_ENDPOINT,
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id="octopus",
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -301,29 +318,25 @@ def test_pipeline_service_client_mtls_env_auto(
                 client_cert_source=client_cert_source_callback
             )
             with mock.patch.object(transport_class, "__init__") as patched:
-                ssl_channel_creds = mock.Mock()
-                with mock.patch(
-                    "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
-                ):
-                    patched.return_value = None
-                    client = client_class(client_options=options)
+                patched.return_value = None
+                client = client_class(client_options=options)

-                    if use_client_cert_env == "false":
-                        expected_ssl_channel_creds = None
-                        expected_host = client.DEFAULT_ENDPOINT
-                    else:
-                        expected_ssl_channel_creds = ssl_channel_creds
-                        expected_host = client.DEFAULT_MTLS_ENDPOINT
+                if use_client_cert_env == "false":
+                    expected_client_cert_source = None
+                    expected_host = client.DEFAULT_ENDPOINT
+                else:
+                    expected_client_cert_source = client_cert_source_callback
+                    expected_host = client.DEFAULT_MTLS_ENDPOINT

-                    patched.assert_called_once_with(
-                        credentials=None,
-                        credentials_file=None,
-                        host=expected_host,
-                        scopes=None,
-                        ssl_channel_credentials=expected_ssl_channel_creds,
-                        quota_project_id=None,
-                        client_info=transports.base.DEFAULT_CLIENT_INFO,
-                    )
+                patched.assert_called_once_with(
+                    credentials=None,
+                    credentials_file=None,
+                    host=expected_host,
+                    scopes=None,
+                    client_cert_source_for_mtls=expected_client_cert_source,
+                    quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
+                )

     # Check the case ADC client cert is provided. Whether client cert is used depends on
     # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
@@ -332,66 +345,53 @@ def test_pipeline_service_client_mtls_env_auto(
     ):
         with mock.patch.object(transport_class, "__init__") as patched:
             with mock.patch(
-                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+                "google.auth.transport.mtls.has_default_client_cert_source",
+                return_value=True,
             ):
                 with mock.patch(
-                    "google.auth.transport.grpc.SslCredentials.is_mtls",
-                    new_callable=mock.PropertyMock,
-                ) as is_mtls_mock:
-                    with mock.patch(
-                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
-                        new_callable=mock.PropertyMock,
-                    ) as ssl_credentials_mock:
-                        if use_client_cert_env == "false":
-                            is_mtls_mock.return_value = False
-                            ssl_credentials_mock.return_value = None
-                            expected_host = client.DEFAULT_ENDPOINT
-                            expected_ssl_channel_creds = None
-                        else:
-                            is_mtls_mock.return_value = True
-                            ssl_credentials_mock.return_value = mock.Mock()
-                            expected_host = client.DEFAULT_MTLS_ENDPOINT
-                            expected_ssl_channel_creds = (
-                                ssl_credentials_mock.return_value
-                            )
-
-                        patched.return_value = None
-                        client = client_class()
-                        patched.assert_called_once_with(
-                            credentials=None,
-                            credentials_file=None,
-                            host=expected_host,
-                            scopes=None,
-                            ssl_channel_credentials=expected_ssl_channel_creds,
-                            quota_project_id=None,
-                            client_info=transports.base.DEFAULT_CLIENT_INFO,
-                        )
+                    "google.auth.transport.mtls.default_client_cert_source",
+                    return_value=client_cert_source_callback,
+                ):
+                    if use_client_cert_env == "false":
+                        expected_host = client.DEFAULT_ENDPOINT
+                        expected_client_cert_source = None
+                    else:
+                        expected_host = client.DEFAULT_MTLS_ENDPOINT
+                        expected_client_cert_source = client_cert_source_callback

-    # Check the case client_cert_source and ADC client cert are not provided.
-    with mock.patch.dict(
-        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
-    ):
-        with mock.patch.object(transport_class, "__init__") as patched:
-            with mock.patch(
-                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
-            ):
-                with mock.patch(
-                    "google.auth.transport.grpc.SslCredentials.is_mtls",
-                    new_callable=mock.PropertyMock,
-                ) as is_mtls_mock:
-                    is_mtls_mock.return_value = False
                     patched.return_value = None
                     client = client_class()
                     patched.assert_called_once_with(
                         credentials=None,
                         credentials_file=None,
-                        host=client.DEFAULT_ENDPOINT,
+                        host=expected_host,
                         scopes=None,
-                        ssl_channel_credentials=None,
+                        client_cert_source_for_mtls=expected_client_cert_source,
                         quota_project_id=None,
                         client_info=transports.base.DEFAULT_CLIENT_INFO,
                     )

+    # Check the case client_cert_source and ADC client cert are not provided.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.mtls.has_default_client_cert_source",
+                return_value=False,
+            ):
+                patched.return_value = None
+                client = client_class()
+                patched.assert_called_once_with(
+                    credentials=None,
+                    credentials_file=None,
+                    host=client.DEFAULT_ENDPOINT,
+                    scopes=None,
+                    client_cert_source_for_mtls=None,
+                    quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
+                )
+

 @pytest.mark.parametrize(
     "client_class,transport_class,transport_name",
@@ -417,7 +417,7 @@ def test_pipeline_service_client_client_options_scopes(
         credentials_file=None,
         host=client.DEFAULT_ENDPOINT,
         scopes=["1", "2"],
-        ssl_channel_credentials=None,
+        client_cert_source_for_mtls=None,
         quota_project_id=None,
         client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -447,7 +447,7 @@ def test_pipeline_service_client_client_options_credentials_file(
         credentials_file="credentials.json",
         host=client.DEFAULT_ENDPOINT,
         scopes=None,
-        ssl_channel_credentials=None,
+        client_cert_source_for_mtls=None,
         quota_project_id=None,
         client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -466,7 +466,7 @@ def test_pipeline_service_client_client_options_from_dict():
         credentials_file=None,
         host="squid.clam.whelk",
         scopes=None,
-        ssl_channel_credentials=None,
+        client_cert_source_for_mtls=None,
         quota_project_id=None,
         client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -1880,6 +1880,51 @@ def test_pipeline_service_transport_auth_adc():
     )


+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.PipelineServiceGrpcTransport,
+        transports.PipelineServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_pipeline_service_grpc_transport_client_cert_source_for_mtls(transport_class):
+    cred = credentials.AnonymousCredentials()
+
+    # Check ssl_channel_credentials is used if provided.
+    with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+        mock_ssl_channel_creds = mock.Mock()
+        transport_class(
+            host="squid.clam.whelk",
+            credentials=cred,
+            ssl_channel_credentials=mock_ssl_channel_creds,
+        )
+        mock_create_channel.assert_called_once_with(
+            "squid.clam.whelk:443",
+            credentials=cred,
+            credentials_file=None,
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            ssl_credentials=mock_ssl_channel_creds,
+            quota_project_id=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+    # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+    # is used.
+    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+            transport_class(
+                credentials=cred,
+                client_cert_source_for_mtls=client_cert_source_callback,
+            )
+            expected_cert, expected_key = client_cert_source_callback()
+            mock_ssl_cred.assert_called_once_with(
+                certificate_chain=expected_cert, private_key=expected_key
+            )
+
+
 def test_pipeline_service_host_no_port():
     client = PipelineServiceClient(
         credentials=credentials.AnonymousCredentials(),
@@ -1901,7 +1946,7 @@ def test_pipeline_service_host_with_port():


 def test_pipeline_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
+    channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())

     # Check that channel is used if provided.
     transport = transports.PipelineServiceGrpcTransport(
@@ -1913,7 +1958,7 @@ def test_pipeline_service_grpc_transport_channel():


 def test_pipeline_service_grpc_asyncio_transport_channel():
-    channel = aio.insecure_channel("http://localhost/")
+    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

     # Check that channel is used if provided.
     transport = transports.PipelineServiceGrpcAsyncIOTransport(
@@ -1924,6 +1969,8 @@ def test_pipeline_service_grpc_asyncio_transport_channel():
     assert transport._ssl_channel_credentials == None


+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
 @pytest.mark.parametrize(
     "transport_class",
     [
@@ -1938,7 +1985,7 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source(
         "grpc.ssl_channel_credentials", autospec=True
     ) as grpc_ssl_channel_cred:
         with mock.patch.object(
-            transport_class, "create_channel", autospec=True
+            transport_class, "create_channel"
         ) as grpc_create_channel:
             mock_ssl_cred = mock.Mock()
             grpc_ssl_channel_cred.return_value = mock_ssl_cred
@@ -1967,11 +2014,17 @@ def test_pipeline_service_transport_channel_mtls_with_client_cert_source(
                 scopes=("https://www.googleapis.com/auth/cloud-platform",),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
             )
             assert transport.grpc_channel == mock_grpc_channel
             assert transport._ssl_channel_credentials == mock_ssl_cred


+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
 @pytest.mark.parametrize(
     "transport_class",
     [
@@ -1987,7 +2040,7 @@ def test_pipeline_service_transport_channel_mtls_with_adc(transport_class):
         ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
     ):
         with mock.patch.object(
-            transport_class, "create_channel", autospec=True
+            transport_class, "create_channel"
        ) as grpc_create_channel:
             mock_grpc_channel = mock.Mock()
             grpc_create_channel.return_value = mock_grpc_channel
@@ -2008,6 +2061,10 @@ def test_pipeline_service_transport_channel_mtls_with_adc(transport_class):
                 scopes=("https://www.googleapis.com/auth/cloud-platform",),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
             )
         assert transport.grpc_channel == mock_grpc_channel
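
The from_service_account_info tests added to each service exercise the
in-memory twin of from_service_account_file. A usage sketch (the JSON path
below is a placeholder, not something this patch ships):

    import json

    from google.cloud.aiplatform_v1beta1.services.pipeline_service import (
        PipelineServiceClient,
    )

    with open("service-account.json") as f:
        info = json.load(f)

    client = PipelineServiceClient.from_service_account_info(info)
    # Equivalent shortcut when the key stays on disk:
    client = PipelineServiceClient.from_service_account_file("service-account.json")
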
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py
index e47e0f62c5..311f03d5b7 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py
@@ -90,8 +90,21 @@ def test__get_default_mtls_endpoint():
     )


+def test_prediction_service_client_from_service_account_info():
+    creds = credentials.AnonymousCredentials()
+    with mock.patch.object(
+        service_account.Credentials, "from_service_account_info"
+    ) as factory:
+        factory.return_value = creds
+        info = {"valid": True}
+        client = PredictionServiceClient.from_service_account_info(info)
+        assert client.transport._credentials == creds
+
+        assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
 @pytest.mark.parametrize(
-    "client_class", [PredictionServiceClient, PredictionServiceAsyncClient]
+    "client_class", [PredictionServiceClient, PredictionServiceAsyncClient,]
 )
 def test_prediction_service_client_from_service_account_file(client_class):
     creds = credentials.AnonymousCredentials()
@@ -110,7 +123,10 @@ def test_prediction_service_client_from_service_account_file(client_class):

 def test_prediction_service_client_get_transport_class():
     transport = PredictionServiceClient.get_transport_class()
-    assert transport == transports.PredictionServiceGrpcTransport
+    available_transports = [
+        transports.PredictionServiceGrpcTransport,
+    ]
+    assert transport in available_transports

     transport = PredictionServiceClient.get_transport_class("grpc")
     assert transport == transports.PredictionServiceGrpcTransport
@@ -161,7 +177,7 @@ def test_prediction_service_client_client_options(
             credentials_file=None,
             host="squid.clam.whelk",
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id=None,
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -177,7 +193,7 @@ def test_prediction_service_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_ENDPOINT,
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id=None,
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -193,7 +209,7 @@ def test_prediction_service_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_MTLS_ENDPOINT,
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id=None,
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -221,7 +237,7 @@ def test_prediction_service_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_ENDPOINT,
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id="octopus",
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -282,29 +298,25 @@ def test_prediction_service_client_mtls_env_auto(
                 client_cert_source=client_cert_source_callback
             )
             with mock.patch.object(transport_class, "__init__") as patched:
-                ssl_channel_creds = mock.Mock()
-                with mock.patch(
-                    "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
-                ):
-                    patched.return_value = None
-                    client = client_class(client_options=options)
+                patched.return_value = None
+                client = client_class(client_options=options)

-                    if use_client_cert_env == "false":
-                        expected_ssl_channel_creds = None
-                        expected_host = client.DEFAULT_ENDPOINT
-                    else:
-                        expected_ssl_channel_creds = ssl_channel_creds
-                        expected_host = client.DEFAULT_MTLS_ENDPOINT
+                if use_client_cert_env == "false":
+                    expected_client_cert_source = None
+                    expected_host = client.DEFAULT_ENDPOINT
+                else:
+                    expected_client_cert_source = client_cert_source_callback
+                    expected_host = client.DEFAULT_MTLS_ENDPOINT

-                    patched.assert_called_once_with(
-                        credentials=None,
-                        credentials_file=None,
-                        host=expected_host,
-                        scopes=None,
-                        ssl_channel_credentials=expected_ssl_channel_creds,
-                        quota_project_id=None,
-                        client_info=transports.base.DEFAULT_CLIENT_INFO,
-                    )
+                patched.assert_called_once_with(
+                    credentials=None,
+                    credentials_file=None,
+                    host=expected_host,
+                    scopes=None,
+                    client_cert_source_for_mtls=expected_client_cert_source,
+                    quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
+                )

     # Check the case ADC client cert is provided. Whether client cert is used depends on
     # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
@@ -313,66 +325,53 @@ def test_prediction_service_client_mtls_env_auto(
     ):
         with mock.patch.object(transport_class, "__init__") as patched:
             with mock.patch(
-                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+                "google.auth.transport.mtls.has_default_client_cert_source",
+                return_value=True,
             ):
                 with mock.patch(
-                    "google.auth.transport.grpc.SslCredentials.is_mtls",
-                    new_callable=mock.PropertyMock,
-                ) as is_mtls_mock:
-                    with mock.patch(
-                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
-                        new_callable=mock.PropertyMock,
-                    ) as ssl_credentials_mock:
-                        if use_client_cert_env == "false":
-                            is_mtls_mock.return_value = False
-                            ssl_credentials_mock.return_value = None
-                            expected_host = client.DEFAULT_ENDPOINT
-                            expected_ssl_channel_creds = None
-                        else:
-                            is_mtls_mock.return_value = True
-                            ssl_credentials_mock.return_value = mock.Mock()
-                            expected_host = client.DEFAULT_MTLS_ENDPOINT
-                            expected_ssl_channel_creds = (
-                                ssl_credentials_mock.return_value
-                            )
-
-                        patched.return_value = None
-                        client = client_class()
-                        patched.assert_called_once_with(
-                            credentials=None,
-                            credentials_file=None,
-                            host=expected_host,
-                            scopes=None,
-                            ssl_channel_credentials=expected_ssl_channel_creds,
-                            quota_project_id=None,
-                            client_info=transports.base.DEFAULT_CLIENT_INFO,
-                        )
+                    "google.auth.transport.mtls.default_client_cert_source",
+                    return_value=client_cert_source_callback,
+                ):
+                    if use_client_cert_env == "false":
+                        expected_host = client.DEFAULT_ENDPOINT
+                        expected_client_cert_source = None
+                    else:
+                        expected_host = client.DEFAULT_MTLS_ENDPOINT
+                        expected_client_cert_source = client_cert_source_callback

-    # Check the case client_cert_source and ADC client cert are not provided.
-    with mock.patch.dict(
-        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
-    ):
-        with mock.patch.object(transport_class, "__init__") as patched:
-            with mock.patch(
-                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
-            ):
-                with mock.patch(
-                    "google.auth.transport.grpc.SslCredentials.is_mtls",
-                    new_callable=mock.PropertyMock,
-                ) as is_mtls_mock:
-                    is_mtls_mock.return_value = False
                     patched.return_value = None
                     client = client_class()
                     patched.assert_called_once_with(
                         credentials=None,
                         credentials_file=None,
-                        host=client.DEFAULT_ENDPOINT,
+                        host=expected_host,
                         scopes=None,
-                        ssl_channel_credentials=None,
+                        client_cert_source_for_mtls=expected_client_cert_source,
                         quota_project_id=None,
                         client_info=transports.base.DEFAULT_CLIENT_INFO,
                     )

+    # Check the case client_cert_source and ADC client cert are not provided.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.mtls.has_default_client_cert_source",
+                return_value=False,
+            ):
+                patched.return_value = None
+                client = client_class()
+                patched.assert_called_once_with(
+                    credentials=None,
+                    credentials_file=None,
+                    host=client.DEFAULT_ENDPOINT,
+                    scopes=None,
+                    client_cert_source_for_mtls=None,
+                    quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
+                )
+

 @pytest.mark.parametrize(
     "client_class,transport_class,transport_name",
@@ -398,7 +397,7 @@ def test_prediction_service_client_client_options_scopes(
         credentials_file=None,
         host=client.DEFAULT_ENDPOINT,
         scopes=["1", "2"],
-        ssl_channel_credentials=None,
+        client_cert_source_for_mtls=None,
         quota_project_id=None,
         client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -428,7 +427,7 @@ def test_prediction_service_client_client_options_credentials_file(
         credentials_file="credentials.json",
         host=client.DEFAULT_ENDPOINT,
         scopes=None,
-        ssl_channel_credentials=None,
+        client_cert_source_for_mtls=None,
         quota_project_id=None,
         client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -447,7 +446,7 @@ def test_prediction_service_client_client_options_from_dict():
         credentials_file=None,
         host="squid.clam.whelk",
         scopes=None,
-        ssl_channel_credentials=None,
+        client_cert_source_for_mtls=None,
         quota_project_id=None,
         client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -1087,6 +1086,51 @@ def test_prediction_service_transport_auth_adc():
     )


+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.PredictionServiceGrpcTransport,
+        transports.PredictionServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_prediction_service_grpc_transport_client_cert_source_for_mtls(transport_class):
+    cred = credentials.AnonymousCredentials()
+
+    # Check ssl_channel_credentials is used if provided.
+    with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+        mock_ssl_channel_creds = mock.Mock()
+        transport_class(
+            host="squid.clam.whelk",
+            credentials=cred,
+            ssl_channel_credentials=mock_ssl_channel_creds,
+        )
+        mock_create_channel.assert_called_once_with(
+            "squid.clam.whelk:443",
+            credentials=cred,
+            credentials_file=None,
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            ssl_credentials=mock_ssl_channel_creds,
+            quota_project_id=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+    # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+    # is used.
+    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+            transport_class(
+                credentials=cred,
+                client_cert_source_for_mtls=client_cert_source_callback,
+            )
+            expected_cert, expected_key = client_cert_source_callback()
+            mock_ssl_cred.assert_called_once_with(
+                certificate_chain=expected_cert, private_key=expected_key
+            )
+
+
 def test_prediction_service_host_no_port():
     client = PredictionServiceClient(
         credentials=credentials.AnonymousCredentials(),
@@ -1108,7 +1152,7 @@ def test_prediction_service_host_with_port():


 def test_prediction_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
+    channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())

     # Check that channel is used if provided.
     transport = transports.PredictionServiceGrpcTransport(
@@ -1120,7 +1164,7 @@ def test_prediction_service_grpc_transport_channel():


 def test_prediction_service_grpc_asyncio_transport_channel():
-    channel = aio.insecure_channel("http://localhost/")
+    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

     # Check that channel is used if provided.
     transport = transports.PredictionServiceGrpcAsyncIOTransport(
@@ -1131,6 +1175,8 @@ def test_prediction_service_grpc_asyncio_transport_channel():
     assert transport._ssl_channel_credentials == None


+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
 @pytest.mark.parametrize(
     "transport_class",
     [
@@ -1145,7 +1191,7 @@ def test_prediction_service_transport_channel_mtls_with_client_cert_source(
         "grpc.ssl_channel_credentials", autospec=True
     ) as grpc_ssl_channel_cred:
         with mock.patch.object(
-            transport_class, "create_channel", autospec=True
+            transport_class, "create_channel"
         ) as grpc_create_channel:
             mock_ssl_cred = mock.Mock()
             grpc_ssl_channel_cred.return_value = mock_ssl_cred
@@ -1174,11 +1220,17 @@ def test_prediction_service_transport_channel_mtls_with_client_cert_source(
                 scopes=("https://www.googleapis.com/auth/cloud-platform",),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
             )
             assert transport.grpc_channel == mock_grpc_channel
             assert transport._ssl_channel_credentials == mock_ssl_cred


+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
 @pytest.mark.parametrize(
     "transport_class",
     [
@@ -1194,7 +1246,7 @@ def test_prediction_service_transport_channel_mtls_with_adc(transport_class):
         ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
     ):
         with mock.patch.object(
-            transport_class, "create_channel", autospec=True
+            transport_class, "create_channel"
         ) as grpc_create_channel:
             mock_grpc_channel = mock.Mock()
             grpc_create_channel.return_value = mock_grpc_channel
@@ -1215,6 +1267,10 @@ def test_prediction_service_transport_channel_mtls_with_adc(transport_class):
                 scopes=("https://www.googleapis.com/auth/cloud-platform",),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
             )
         assert transport.grpc_channel == mock_grpc_channel
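
On the channel fixtures switching from insecure_channel to secure_channel:
unlike grpc.insecure_channel(), a channel built with
grpc.local_channel_credentials() carries a real credentials object while
skipping CA verification, so the transports' _ssl_channel_credentials
bookkeeping can still be asserted. Sketch (the target address is
illustrative):

    import grpc

    # Local credentials are intended only for local connections
    # (localhost or Unix domain sockets), never for real endpoints.
    channel = grpc.secure_channel("localhost:8080", grpc.local_channel_credentials())
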
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py
index 6c1061d588..c91839ff1a 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_specialist_pool_service.py
@@ -97,8 +97,21 @@ def test__get_default_mtls_endpoint():
     )


+def test_specialist_pool_service_client_from_service_account_info():
+    creds = credentials.AnonymousCredentials()
+    with mock.patch.object(
+        service_account.Credentials, "from_service_account_info"
+    ) as factory:
+        factory.return_value = creds
+        info = {"valid": True}
+        client = SpecialistPoolServiceClient.from_service_account_info(info)
+        assert client.transport._credentials == creds
+
+        assert client.transport._host == "aiplatform.googleapis.com:443"
+
+
 @pytest.mark.parametrize(
-    "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient]
+    "client_class", [SpecialistPoolServiceClient, SpecialistPoolServiceAsyncClient,]
 )
 def test_specialist_pool_service_client_from_service_account_file(client_class):
     creds = credentials.AnonymousCredentials()
@@ -117,7 +130,10 @@ def test_specialist_pool_service_client_from_service_account_file(client_class):

 def test_specialist_pool_service_client_get_transport_class():
     transport = SpecialistPoolServiceClient.get_transport_class()
-    assert transport == transports.SpecialistPoolServiceGrpcTransport
+    available_transports = [
+        transports.SpecialistPoolServiceGrpcTransport,
+    ]
+    assert transport in available_transports

     transport = SpecialistPoolServiceClient.get_transport_class("grpc")
     assert transport == transports.SpecialistPoolServiceGrpcTransport
@@ -172,7 +188,7 @@ def test_specialist_pool_service_client_client_options(
             credentials_file=None,
             host="squid.clam.whelk",
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id=None,
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -188,7 +204,7 @@ def test_specialist_pool_service_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_ENDPOINT,
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id=None,
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -204,7 +220,7 @@ def test_specialist_pool_service_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_MTLS_ENDPOINT,
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id=None,
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -232,7 +248,7 @@ def test_specialist_pool_service_client_client_options(
             credentials_file=None,
             host=client.DEFAULT_ENDPOINT,
             scopes=None,
-            ssl_channel_credentials=None,
+            client_cert_source_for_mtls=None,
             quota_project_id="octopus",
             client_info=transports.base.DEFAULT_CLIENT_INFO,
         )
@@ -293,29 +309,25 @@ def test_specialist_pool_service_client_mtls_env_auto(
                 client_cert_source=client_cert_source_callback
             )
             with mock.patch.object(transport_class, "__init__") as patched:
-                ssl_channel_creds = mock.Mock()
-                with mock.patch(
-                    "grpc.ssl_channel_credentials", return_value=ssl_channel_creds
-                ):
-                    patched.return_value = None
-                    client = client_class(client_options=options)
+                patched.return_value = None
+                client = client_class(client_options=options)

-                    if use_client_cert_env == "false":
-                        expected_ssl_channel_creds = None
-                        expected_host = client.DEFAULT_ENDPOINT
-                    else:
-                        expected_ssl_channel_creds = ssl_channel_creds
-                        expected_host = client.DEFAULT_MTLS_ENDPOINT
+                if use_client_cert_env == "false":
+                    expected_client_cert_source = None
+                    expected_host = client.DEFAULT_ENDPOINT
+                else:
+                    expected_client_cert_source = client_cert_source_callback
+                    expected_host = client.DEFAULT_MTLS_ENDPOINT

-                    patched.assert_called_once_with(
-                        credentials=None,
-                        credentials_file=None,
-                        host=expected_host,
-                        scopes=None,
-                        ssl_channel_credentials=expected_ssl_channel_creds,
-                        quota_project_id=None,
-                        client_info=transports.base.DEFAULT_CLIENT_INFO,
-                    )
+                patched.assert_called_once_with(
+                    credentials=None,
+                    credentials_file=None,
+                    host=expected_host,
+                    scopes=None,
+                    client_cert_source_for_mtls=expected_client_cert_source,
+                    quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
+                )

     # Check the case ADC client cert is provided. Whether client cert is used depends on
     # GOOGLE_API_USE_CLIENT_CERTIFICATE value.
@@ -324,66 +336,53 @@ def test_specialist_pool_service_client_mtls_env_auto(
     ):
         with mock.patch.object(transport_class, "__init__") as patched:
             with mock.patch(
-                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
+                "google.auth.transport.mtls.has_default_client_cert_source",
+                return_value=True,
             ):
                 with mock.patch(
-                    "google.auth.transport.grpc.SslCredentials.is_mtls",
-                    new_callable=mock.PropertyMock,
-                ) as is_mtls_mock:
-                    with mock.patch(
-                        "google.auth.transport.grpc.SslCredentials.ssl_credentials",
-                        new_callable=mock.PropertyMock,
-                    ) as ssl_credentials_mock:
-                        if use_client_cert_env == "false":
-                            is_mtls_mock.return_value = False
-                            ssl_credentials_mock.return_value = None
-                            expected_host = client.DEFAULT_ENDPOINT
-                            expected_ssl_channel_creds = None
-                        else:
-                            is_mtls_mock.return_value = True
-                            ssl_credentials_mock.return_value = mock.Mock()
-                            expected_host = client.DEFAULT_MTLS_ENDPOINT
-                            expected_ssl_channel_creds = (
-                                ssl_credentials_mock.return_value
-                            )
-
-                        patched.return_value = None
-                        client = client_class()
-                        patched.assert_called_once_with(
-                            credentials=None,
-                            credentials_file=None,
-                            host=expected_host,
-                            scopes=None,
-                            ssl_channel_credentials=expected_ssl_channel_creds,
-                            quota_project_id=None,
-                            client_info=transports.base.DEFAULT_CLIENT_INFO,
-                        )
+                    "google.auth.transport.mtls.default_client_cert_source",
+                    return_value=client_cert_source_callback,
+                ):
+                    if use_client_cert_env == "false":
+                        expected_host = client.DEFAULT_ENDPOINT
+                        expected_client_cert_source = None
+                    else:
+                        expected_host = client.DEFAULT_MTLS_ENDPOINT
+                        expected_client_cert_source = client_cert_source_callback

-    # Check the case client_cert_source and ADC client cert are not provided.
-    with mock.patch.dict(
-        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
-    ):
-        with mock.patch.object(transport_class, "__init__") as patched:
-            with mock.patch(
-                "google.auth.transport.grpc.SslCredentials.__init__", return_value=None
-            ):
-                with mock.patch(
-                    "google.auth.transport.grpc.SslCredentials.is_mtls",
-                    new_callable=mock.PropertyMock,
-                ) as is_mtls_mock:
-                    is_mtls_mock.return_value = False
                     patched.return_value = None
                     client = client_class()
                     patched.assert_called_once_with(
                         credentials=None,
                         credentials_file=None,
-                        host=client.DEFAULT_ENDPOINT,
+                        host=expected_host,
                         scopes=None,
-                        ssl_channel_credentials=None,
+                        client_cert_source_for_mtls=expected_client_cert_source,
                         quota_project_id=None,
                         client_info=transports.base.DEFAULT_CLIENT_INFO,
                     )

+    # Check the case client_cert_source and ADC client cert are not provided.
+    with mock.patch.dict(
+        os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env}
+    ):
+        with mock.patch.object(transport_class, "__init__") as patched:
+            with mock.patch(
+                "google.auth.transport.mtls.has_default_client_cert_source",
+                return_value=False,
+            ):
+                patched.return_value = None
+                client = client_class()
+                patched.assert_called_once_with(
+                    credentials=None,
+                    credentials_file=None,
+                    host=client.DEFAULT_ENDPOINT,
+                    scopes=None,
+                    client_cert_source_for_mtls=None,
+                    quota_project_id=None,
+                    client_info=transports.base.DEFAULT_CLIENT_INFO,
+                )
+

 @pytest.mark.parametrize(
     "client_class,transport_class,transport_name",
@@ -413,7 +412,7 @@ def test_specialist_pool_service_client_client_options_scopes(
         credentials_file=None,
         host=client.DEFAULT_ENDPOINT,
         scopes=["1", "2"],
-        ssl_channel_credentials=None,
+        client_cert_source_for_mtls=None,
         quota_project_id=None,
         client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -447,7 +446,7 @@ def test_specialist_pool_service_client_client_options_credentials_file(
         credentials_file="credentials.json",
         host=client.DEFAULT_ENDPOINT,
         scopes=None,
-        ssl_channel_credentials=None,
+        client_cert_source_for_mtls=None,
         quota_project_id=None,
         client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -466,7 +465,7 @@ def test_specialist_pool_service_client_client_options_from_dict():
         credentials_file=None,
         host="squid.clam.whelk",
         scopes=None,
-        ssl_channel_credentials=None,
+        client_cert_source_for_mtls=None,
         quota_project_id=None,
         client_info=transports.base.DEFAULT_CLIENT_INFO,
     )
@@ -1960,6 +1959,53 @@ def test_specialist_pool_service_transport_auth_adc():
     )


+@pytest.mark.parametrize(
+    "transport_class",
+    [
+        transports.SpecialistPoolServiceGrpcTransport,
+        transports.SpecialistPoolServiceGrpcAsyncIOTransport,
+    ],
+)
+def test_specialist_pool_service_grpc_transport_client_cert_source_for_mtls(
+    transport_class,
+):
+    cred = credentials.AnonymousCredentials()
+
+    # Check ssl_channel_credentials is used if provided.
+    with mock.patch.object(transport_class, "create_channel") as mock_create_channel:
+        mock_ssl_channel_creds = mock.Mock()
+        transport_class(
+            host="squid.clam.whelk",
+            credentials=cred,
+            ssl_channel_credentials=mock_ssl_channel_creds,
+        )
+        mock_create_channel.assert_called_once_with(
+            "squid.clam.whelk:443",
+            credentials=cred,
+            credentials_file=None,
+            scopes=("https://www.googleapis.com/auth/cloud-platform",),
+            ssl_credentials=mock_ssl_channel_creds,
+            quota_project_id=None,
+            options=[
+                ("grpc.max_send_message_length", -1),
+                ("grpc.max_receive_message_length", -1),
+            ],
+        )
+
+    # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls
+    # is used.
+    with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()):
+        with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred:
+            transport_class(
+                credentials=cred,
+                client_cert_source_for_mtls=client_cert_source_callback,
+            )
+            expected_cert, expected_key = client_cert_source_callback()
+            mock_ssl_cred.assert_called_once_with(
+                certificate_chain=expected_cert, private_key=expected_key
+            )
+
+
 def test_specialist_pool_service_host_no_port():
     client = SpecialistPoolServiceClient(
         credentials=credentials.AnonymousCredentials(),
@@ -1981,7 +2027,7 @@ def test_specialist_pool_service_host_with_port():


 def test_specialist_pool_service_grpc_transport_channel():
-    channel = grpc.insecure_channel("http://localhost/")
+    channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials())

     # Check that channel is used if provided.
     transport = transports.SpecialistPoolServiceGrpcTransport(
@@ -1993,7 +2039,7 @@ def test_specialist_pool_service_grpc_transport_channel():


 def test_specialist_pool_service_grpc_asyncio_transport_channel():
-    channel = aio.insecure_channel("http://localhost/")
+    channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials())

     # Check that channel is used if provided.
     transport = transports.SpecialistPoolServiceGrpcAsyncIOTransport(
@@ -2004,6 +2050,8 @@ def test_specialist_pool_service_grpc_asyncio_transport_channel():
     assert transport._ssl_channel_credentials == None


+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
 @pytest.mark.parametrize(
     "transport_class",
     [
@@ -2018,7 +2066,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source(
         "grpc.ssl_channel_credentials", autospec=True
     ) as grpc_ssl_channel_cred:
         with mock.patch.object(
-            transport_class, "create_channel", autospec=True
+            transport_class, "create_channel"
         ) as grpc_create_channel:
             mock_ssl_cred = mock.Mock()
             grpc_ssl_channel_cred.return_value = mock_ssl_cred
@@ -2047,11 +2095,17 @@ def test_specialist_pool_service_transport_channel_mtls_with_client_cert_source(
                 scopes=("https://www.googleapis.com/auth/cloud-platform",),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
             )
             assert transport.grpc_channel == mock_grpc_channel
             assert transport._ssl_channel_credentials == mock_ssl_cred


+# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are
+# removed from grpc/grpc_asyncio transport constructor.
 @pytest.mark.parametrize(
     "transport_class",
     [
@@ -2067,7 +2121,7 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class
         ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred),
     ):
         with mock.patch.object(
-            transport_class, "create_channel", autospec=True
+            transport_class, "create_channel"
         ) as grpc_create_channel:
             mock_grpc_channel = mock.Mock()
             grpc_create_channel.return_value = mock_grpc_channel
@@ -2088,6 +2142,10 @@ def test_specialist_pool_service_transport_channel_mtls_with_adc(transport_class
                 scopes=("https://www.googleapis.com/auth/cloud-platform",),
                 ssl_credentials=mock_ssl_cred,
                 quota_project_id=None,
+                options=[
+                    ("grpc.max_send_message_length", -1),
+                    ("grpc.max_receive_message_length", -1),
+                ],
            )
         assert transport.grpc_channel == mock_grpc_channel
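
For reference, the mtls_env_auto tests above drive endpoint selection off a
single environment variable. A simplified paraphrase of the decision they
assert (real clients also validate the variable's value and consult further
settings, so treat this as an approximation, not the library's code):

    import os

    def effective_host(client_cert, default_host, mtls_host):
        # A client certificate is honored only when the env var opts in.
        use_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true"
        if use_cert and client_cert is not None:
            return mtls_host
        return default_host
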