/
jobs.py
1691 lines (1450 loc) · 71.5 KB
/
jobs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# -*- coding: utf-8 -*-
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Iterable, Optional, Union, Sequence, Dict, List
import abc
import copy
import datetime
import time
from google.cloud import storage
from google.cloud import bigquery
from google.auth import credentials as auth_credentials
from google.protobuf import duration_pb2 # type: ignore
from google.rpc import status_pb2
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import compat
from google.cloud.aiplatform import constants
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform import hyperparameter_tuning
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.utils import console_utils
from google.cloud.aiplatform.utils import source_utils
from google.cloud.aiplatform.utils import worker_spec_utils
from google.cloud.aiplatform.compat.services import job_service_client
from google.cloud.aiplatform.compat.types import (
batch_prediction_job as gca_bp_job_compat,
batch_prediction_job_v1 as gca_bp_job_v1,
batch_prediction_job_v1beta1 as gca_bp_job_v1beta1,
completion_stats as gca_completion_stats,
custom_job as gca_custom_job_compat,
custom_job_v1beta1 as gca_custom_job_v1beta1,
explanation_v1beta1 as gca_explanation_v1beta1,
io as gca_io_compat,
io_v1beta1 as gca_io_v1beta1,
job_state as gca_job_state,
hyperparameter_tuning_job as gca_hyperparameter_tuning_job_compat,
hyperparameter_tuning_job_v1beta1 as gca_hyperparameter_tuning_job_v1beta1,
machine_resources as gca_machine_resources_compat,
machine_resources_v1beta1 as gca_machine_resources_v1beta1,
study as gca_study_compat,
)
# Module-level logger shared by all Job classes in this file.
_LOGGER = base.Logger(__name__)

# Terminal (or paused) states: `_Job._block_until_complete` stops polling
# once the job reaches one of these.
_JOB_COMPLETE_STATES = (
    gca_job_state.JobState.JOB_STATE_SUCCEEDED,
    gca_job_state.JobState.JOB_STATE_FAILED,
    gca_job_state.JobState.JOB_STATE_CANCELLED,
    gca_job_state.JobState.JOB_STATE_PAUSED,
)

# States treated as failures: `_Job._block_until_complete` raises
# RuntimeError when the job ends in one of these.
_JOB_ERROR_STATES = (
    gca_job_state.JobState.JOB_STATE_FAILED,
    gca_job_state.JobState.JOB_STATE_CANCELLED,
)
class _Job(base.VertexAiResourceNounWithFutureManager):
    """Class that represents a general Job resource in Vertex AI.

    Cannot be directly instantiated.

    Serves as base class to specific Job types, i.e. BatchPredictionJob or
    DataLabelingJob to re-use shared functionality.

    Subclasses require these class attributes:

    _getter_method (str): The name of JobServiceClient getter method for specific
        Job type, i.e. 'get_custom_job' for CustomJob
    _cancel_method (str): The name of the specific JobServiceClient cancel method
    _delete_method (str): The name of the specific JobServiceClient delete method
    """

    client_class = utils.JobClientWithOverride
    _is_client_prediction_client = False

    def __init__(
        self,
        job_name: str,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ):
        """Retrieves Job subclass resource by calling a subclass-specific getter
        method.

        Args:
            job_name (str):
                Required. A fully-qualified job resource name or job ID.
                Example: "projects/123/locations/us-central1/batchPredictionJobs/456" or
                "456" when project, location and job_type are initialized or passed.
            project (str):
                Optional project to retrieve Job subclass from. If not set,
                project set in aiplatform.init will be used.
            location (str):
                Optional location to retrieve Job subclass from. If not set,
                location set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Custom credentials to use. If not set, credentials set in
                aiplatform.init will be used.
        """
        super().__init__(
            project=project,
            location=location,
            credentials=credentials,
            resource_name=job_name,
        )
        # Eagerly fetch the backing proto so properties are available immediately.
        self._gca_resource = self._get_gca_resource(resource_name=job_name)

    @property
    def state(self) -> gca_job_state.JobState:
        """Fetch Job again and return the current JobState.

        Returns:
            state (job_state.JobState):
                Enum that describes the state of a Vertex AI job.
        """
        # Fetch the Job again for most up-to-date job state.
        self._sync_gca_resource()
        return self._gca_resource.state

    @property
    def start_time(self) -> Optional[datetime.datetime]:
        """Time when the Job resource entered the `JOB_STATE_RUNNING` for the
        first time."""
        self._sync_gca_resource()
        return self._gca_resource.start_time

    @property
    def end_time(self) -> Optional[datetime.datetime]:
        """Time when the Job resource entered the `JOB_STATE_SUCCEEDED`,
        `JOB_STATE_FAILED`, or `JOB_STATE_CANCELLED` state."""
        self._sync_gca_resource()
        return self._gca_resource.end_time

    @property
    def error(self) -> Optional[status_pb2.Status]:
        """Detailed error info for this Job resource. Only populated when the
        Job's state is `JOB_STATE_FAILED` or `JOB_STATE_CANCELLED`."""
        self._sync_gca_resource()
        return self._gca_resource.error

    @property
    @abc.abstractmethod
    def _job_type(cls) -> str:
        """Job type, used as the URL path segment in `_dashboard_uri`."""
        pass

    @property
    @abc.abstractmethod
    def _cancel_method(cls) -> str:
        """Name of cancellation method for cancelling the specific job type."""
        pass

    def _dashboard_uri(self) -> Optional[str]:
        """Helper method to compose the dashboard uri where job can be
        viewed."""
        fields = utils.extract_fields_from_resource_name(self.resource_name)
        url = f"https://console.cloud.google.com/ai/platform/locations/{fields.location}/{self._job_type}/{fields.id}?project={fields.project}"
        return url

    def _block_until_complete(self):
        """Helper method to block and check on job until complete.

        Polls at a fixed interval; the *logging* interval backs off
        exponentially (capped) so long-running jobs do not spam the log
        while failures still surface fast.

        Raises:
            RuntimeError: If job failed or cancelled.
        """
        wait = 5  # fixed polling interval in seconds; never scaled
        log_wait = 5  # seconds between state log lines; grows below
        max_wait = 60 * 5  # cap on the logging interval: 5 minutes
        multiplier = 2  # scale log_wait by 2 every time we log

        previous_time = time.time()
        while self.state not in _JOB_COMPLETE_STATES:
            current_time = time.time()
            if current_time - previous_time >= log_wait:
                _LOGGER.info(
                    "%s %s current state:\n%s"
                    % (
                        self.__class__.__name__,
                        self._gca_resource.name,
                        self._gca_resource.state,
                    )
                )
                log_wait = min(log_wait * multiplier, max_wait)
                previous_time = current_time
            time.sleep(wait)

        # Log the terminal state once the loop exits.
        _LOGGER.info(
            "%s %s current state:\n%s"
            % (
                self.__class__.__name__,
                self._gca_resource.name,
                self._gca_resource.state,
            )
        )

        # Error is only populated when the job state is
        # JOB_STATE_FAILED or JOB_STATE_CANCELLED.
        if self._gca_resource.state in _JOB_ERROR_STATES:
            raise RuntimeError("Job failed with:\n%s" % self._gca_resource.error)
        else:
            _LOGGER.log_action_completed_against_resource("run", "completed", self)

    @classmethod
    def list(
        cls,
        filter: Optional[str] = None,
        order_by: Optional[str] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> List[base.VertexAiResourceNoun]:
        """List all instances of this Job Resource.

        Example Usage:

            aiplatform.BatchPredictionJobs.list(
                filter='state="JOB_STATE_SUCCEEDED" AND display_name="my_job"',
            )

        Args:
            filter (str):
                Optional. An expression for filtering the results of the request.
                For field names both snake_case and camelCase are supported.
            order_by (str):
                Optional. A comma-separated list of fields to order by, sorted in
                ascending order. Use "desc" after a field name for descending.
                Supported fields: `display_name`, `create_time`, `update_time`
            project (str):
                Optional. Project to retrieve list from. If not set, project
                set in aiplatform.init will be used.
            location (str):
                Optional. Location to retrieve list from. If not set, location
                set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Optional. Custom credentials to use to retrieve list. Overrides
                credentials set in aiplatform.init.

        Returns:
            List[VertexAiResourceNoun] - A list of Job resource objects
        """
        return cls._list_with_local_order(
            filter=filter,
            order_by=order_by,
            project=project,
            location=location,
            credentials=credentials,
        )

    def cancel(self) -> None:
        """Cancels this Job.

        Success of cancellation is not guaranteed. Use `Job.state`
        property to verify if cancellation was successful.
        """
        _LOGGER.log_action_start_against_resource("Cancelling", "run", self)
        # Dispatch to the subclass-specific cancel RPC by attribute name.
        getattr(self.api_client, self._cancel_method)(name=self.resource_name)
class BatchPredictionJob(_Job):
    """Job subclass representing a Vertex AI BatchPredictionJob resource."""

    # Collection/identifier strings wired into the _Job machinery.
    _resource_noun = "batchPredictionJobs"
    # JobServiceClient method names used by the _Job base class.
    _getter_method = "get_batch_prediction_job"
    _list_method = "list_batch_prediction_jobs"
    _cancel_method = "cancel_batch_prediction_job"
    _delete_method = "delete_batch_prediction_job"
    # URL path segment used when composing the console dashboard URI.
    _job_type = "batch-predictions"
def __init__(
    self,
    batch_prediction_job_name: str,
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
):
    """Fetch an existing BatchPredictionJob and wrap it in this class.

    Args:
        batch_prediction_job_name (str):
            Required. A fully-qualified BatchPredictionJob resource name or ID.
            Example: "projects/.../locations/.../batchPredictionJobs/456" or
            "456" when project and location are initialized or passed.
        project (str):
            Optional. Project to retrieve the BatchPredictionJob from;
            falls back to the project set in aiplatform.init.
        location (str):
            Optional. Location to retrieve the BatchPredictionJob from;
            falls back to the location set in aiplatform.init.
        credentials (auth_credentials.Credentials):
            Optional. Custom credentials; falls back to credentials set
            in aiplatform.init.
    """
    # All heavy lifting (resource lookup) happens in the _Job base class.
    super().__init__(
        project=project,
        location=location,
        credentials=credentials,
        job_name=batch_prediction_job_name,
    )
@property
def output_info(self,) -> Optional[aiplatform.gapic.BatchPredictionJob.OutputInfo]:
    """Information describing the output of this job, including the output
    location into which prediction output is written.

    Only available for batch prediction jobs that have run successfully.
    """
    self._assert_gca_resource_is_available()
    resource = self._gca_resource
    return resource.output_info
@property
def partial_failures(self) -> Optional[Sequence[status_pb2.Status]]:
    """Partial failures encountered during the job, e.g. single files that
    could not be read.

    Never exceeds 20 entries; status details fields contain standard
    GCP error details.
    """
    self._assert_gca_resource_is_available()
    return self._gca_resource.partial_failures
@property
def completion_stats(self) -> Optional[gca_completion_stats.CompletionStats]:
    """Statistics on completed and failed prediction instances."""
    self._assert_gca_resource_is_available()
    return self._gca_resource.completion_stats
@classmethod
def create(
    cls,
    job_display_name: str,
    model_name: str,
    instances_format: str = "jsonl",
    predictions_format: str = "jsonl",
    gcs_source: Optional[Union[str, Sequence[str]]] = None,
    bigquery_source: Optional[str] = None,
    gcs_destination_prefix: Optional[str] = None,
    bigquery_destination_prefix: Optional[str] = None,
    model_parameters: Optional[Dict] = None,
    machine_type: Optional[str] = None,
    accelerator_type: Optional[str] = None,
    accelerator_count: Optional[int] = None,
    starting_replica_count: Optional[int] = None,
    max_replica_count: Optional[int] = None,
    generate_explanation: Optional[bool] = False,
    explanation_metadata: Optional["aiplatform.explain.ExplanationMetadata"] = None,
    explanation_parameters: Optional[
        "aiplatform.explain.ExplanationParameters"
    ] = None,
    labels: Optional[Dict[str, str]] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
    encryption_spec_key_name: Optional[str] = None,
    sync: bool = True,
) -> "BatchPredictionJob":
    """Create a batch prediction job.

    Args:
        job_display_name (str):
            Required. The user-defined name of the BatchPredictionJob.
            The name can be up to 128 characters long and can be consist
            of any UTF-8 characters.
        model_name (str):
            Required. A fully-qualified model resource name or model ID.
            Example: "projects/123/locations/us-central1/models/456" or
            "456" when project and location are initialized or passed.
        instances_format (str):
            Required. The format in which instances are given, must be one
            of "jsonl", "csv", "bigquery", "tf-record", "tf-record-gzip",
            or "file-list". Default is "jsonl" when using `gcs_source`. If a
            `bigquery_source` is provided, this is overridden to "bigquery".
        predictions_format (str):
            Required. The format in which Vertex AI gives the
            predictions, must be one of "jsonl", "csv", or "bigquery".
            Default is "jsonl" when using `gcs_destination_prefix`. If a
            `bigquery_destination_prefix` is provided, this is overridden to
            "bigquery".
        gcs_source (Optional[Union[str, Sequence[str]]]):
            Google Cloud Storage URI(-s) to your instances to run
            batch prediction on. They must match `instances_format`.
            May contain wildcards. For more information on wildcards, see
            https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
        bigquery_source (Optional[str]):
            BigQuery URI to a table, up to 2000 characters long. For example:
            `bq://projectId.bqDatasetId.bqTableId`
        gcs_destination_prefix (Optional[str]):
            The Google Cloud Storage location of the directory where the
            output is to be written to. In the given directory a new
            directory is created. Its name is
            ``prediction-<model-display-name>-<job-create-time>``, where
            timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format.
            Inside of it files ``predictions_0001.<extension>``,
            ``predictions_0002.<extension>``, ...,
            ``predictions_N.<extension>`` are created where
            ``<extension>`` depends on chosen ``predictions_format``,
            and N may equal 0001 and depends on the total number of
            successfully predicted instances. If the Model has both
            ``instance`` and ``prediction`` schemata defined then each such
            file contains predictions as per the ``predictions_format``.
            If prediction for any instance failed (partially or
            completely), then an additional ``errors_0001.<extension>``,
            ``errors_0002.<extension>``,..., ``errors_N.<extension>``
            files are created (N depends on total number of failed
            predictions). These files contain the failed instances, as
            per their schema, followed by an additional ``error`` field
            which as value has ```google.rpc.Status`` <Status>`__
            containing only ``code`` and ``message`` fields.
        bigquery_destination_prefix (Optional[str]):
            The BigQuery project location where the output is to be
            written to. In the given project a new dataset is created
            with name
            ``prediction_<model-display-name>_<job-create-time>`` where
            is made BigQuery-dataset-name compatible (for example, most
            special characters become underscores), and timestamp is in
            YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601" format. In the
            dataset two tables will be created, ``predictions``, and
            ``errors``. If the Model has both ``instance`` and ``prediction``
            schemata defined then the tables have columns as follows:
            The ``predictions`` table contains instances for which the
            prediction succeeded, it has columns as per a concatenation
            of the Model's instance and prediction schemata. The
            ``errors`` table contains rows for which the prediction has
            failed, it has instance columns, as per the instance schema,
            followed by a single "errors" column, which as values has
            ```google.rpc.Status`` <Status>`__ represented as a STRUCT,
            and containing only ``code`` and ``message``.
        model_parameters (Optional[Dict]):
            The parameters that govern the predictions. The schema of
            the parameters may be specified via the Model's `parameters_schema_uri`.
        machine_type (Optional[str]):
            The type of machine for running batch prediction on
            dedicated resources. Not specifying machine type will result in
            batch prediction job being run with automatic resources.
        accelerator_type (Optional[str]):
            The type of accelerator(s) that may be attached
            to the machine as per `accelerator_count`. Only used if
            `machine_type` is set.
        accelerator_count (Optional[int]):
            The number of accelerators to attach to the
            `machine_type`. Only used if `machine_type` is set.
        starting_replica_count (Optional[int]):
            The number of machine replicas used at the start of the batch
            operation. If not set, Vertex AI decides starting number, not
            greater than `max_replica_count`. Only used if `machine_type` is
            set.
        max_replica_count (Optional[int]):
            The maximum number of machine replicas the batch operation may
            be scaled to. Only used if `machine_type` is set.
            Default is 10.
        generate_explanation (bool):
            Optional. Generate explanation along with the batch prediction
            results. This will cause the batch prediction output to include
            explanations based on the `prediction_format`:
            - `bigquery`: output includes a column named `explanation`. The value
                is a struct that conforms to the [aiplatform.gapic.Explanation] object.
            - `jsonl`: The JSON objects on each line include an additional entry
                keyed `explanation`. The value of the entry is a JSON object that
                conforms to the [aiplatform.gapic.Explanation] object.
            - `csv`: Generating explanations for CSV format is not supported.
        explanation_metadata (aiplatform.explain.ExplanationMetadata):
            Optional. Explanation metadata configuration for this BatchPredictionJob.
            Can be specified only if `generate_explanation` is set to `True`.
            This value overrides the value of `Model.explanation_metadata`.
            All fields of `explanation_metadata` are optional in the request. If
            a field of the `explanation_metadata` object is not populated, the
            corresponding field of the `Model.explanation_metadata` object is inherited.
            For more details, see `Ref docs <http://tinyurl.com/1igh60kt>`
        explanation_parameters (aiplatform.explain.ExplanationParameters):
            Optional. Parameters to configure explaining for Model's predictions.
            Can be specified only if `generate_explanation` is set to `True`.
            This value overrides the value of `Model.explanation_parameters`.
            All fields of `explanation_parameters` are optional in the request. If
            a field of the `explanation_parameters` object is not populated, the
            corresponding field of the `Model.explanation_parameters` object is inherited.
            For more details, see `Ref docs <http://tinyurl.com/1an4zake>`
        labels (Dict[str, str]):
            Optional. The labels with user-defined metadata to organize your
            BatchPredictionJobs. Label keys and values can be no longer than
            64 characters (Unicode codepoints), can only contain lowercase
            letters, numeric characters, underscores and dashes.
            International characters are allowed. See https://goo.gl/xmQnxf
            for more information and examples of labels.
        project (Optional[str]):
            Project to create this batch prediction job in. Overrides
            project set in aiplatform.init.
        location (Optional[str]):
            Location to create this batch prediction job in. Overrides
            location set in aiplatform.init.
        credentials (Optional[auth_credentials.Credentials]):
            Custom credentials to use to create this batch prediction
            job. Overrides credentials set in aiplatform.init.
        encryption_spec_key_name (Optional[str]):
            Optional. The Cloud KMS resource identifier of the customer
            managed encryption key used to protect the job. Has the
            form:
            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
            The key needs to be in the same region as where the compute
            resource is created.
            If this is set, then all
            resources created by the BatchPredictionJob will
            be encrypted with the provided encryption key.
            Overrides encryption_spec_key_name set in aiplatform.init.
        sync (bool):
            Whether to execute this method synchronously. If False, this method
            will be executed in concurrent Future and any downstream object will
            be immediately returned and synced when the Future has completed.

    Returns:
        (jobs.BatchPredictionJob):
            Instantiated representation of the created batch prediction job.

    Raises:
        ValueError:
            If both or neither of `gcs_source`/`bigquery_source` are
            provided, if both or neither of `gcs_destination_prefix`/
            `bigquery_destination_prefix` are provided, or if
            `instances_format` or `predictions_format` is unsupported.
    """
    utils.validate_display_name(job_display_name)

    if labels:
        utils.validate_labels(labels)

    model_name = utils.full_resource_name(
        resource_name=model_name,
        resource_noun="models",
        project=project,
        location=location,
    )

    # Raise error if both or neither source URIs are provided
    if bool(gcs_source) == bool(bigquery_source):
        raise ValueError(
            "Please provide either a gcs_source or bigquery_source, "
            "but not both."
        )

    # Raise error if both or neither destination prefixes are provided
    if bool(gcs_destination_prefix) == bool(bigquery_destination_prefix):
        raise ValueError(
            "Please provide either a gcs_destination_prefix or "
            "bigquery_destination_prefix, but not both."
        )

    # Raise error if unsupported instance format is provided.
    # (Bug fix: this message previously interpolated `predictions_format`.)
    if instances_format not in constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS:
        raise ValueError(
            f"{instances_format} is not an accepted instances format "
            f"type. Please choose from: {constants.BATCH_PREDICTION_INPUT_STORAGE_FORMATS}"
        )

    # Raise error if unsupported prediction format is provided
    if predictions_format not in constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS:
        raise ValueError(
            f"{predictions_format} is not an accepted prediction format "
            f"type. Please choose from: {constants.BATCH_PREDICTION_OUTPUT_STORAGE_FORMATS}"
        )

    # Explanations are only available on the v1beta1 API surface, so switch
    # the whole proto/module family to v1beta1 when they are requested.
    gca_bp_job = gca_bp_job_compat
    gca_io = gca_io_compat
    gca_machine_resources = gca_machine_resources_compat
    select_version = compat.DEFAULT_VERSION
    if generate_explanation:
        gca_bp_job = gca_bp_job_v1beta1
        gca_io = gca_io_v1beta1
        gca_machine_resources = gca_machine_resources_v1beta1
        select_version = compat.V1BETA1

    gapic_batch_prediction_job = gca_bp_job.BatchPredictionJob()

    # Required Fields
    gapic_batch_prediction_job.display_name = job_display_name
    gapic_batch_prediction_job.model = model_name

    input_config = gca_bp_job.BatchPredictionJob.InputConfig()
    output_config = gca_bp_job.BatchPredictionJob.OutputConfig()

    if bigquery_source:
        # A BigQuery source forces the "bigquery" instances format.
        input_config.instances_format = "bigquery"
        input_config.bigquery_source = gca_io.BigQuerySource()
        input_config.bigquery_source.input_uri = bigquery_source
    else:
        input_config.instances_format = instances_format
        # Accept a single URI string or any sequence of URI strings
        # (previously only `list` was recognized as a sequence).
        input_config.gcs_source = gca_io.GcsSource(
            uris=[gcs_source] if isinstance(gcs_source, str) else list(gcs_source)
        )

    if bigquery_destination_prefix:
        # A BigQuery destination forces the "bigquery" predictions format.
        output_config.predictions_format = "bigquery"
        output_config.bigquery_destination = gca_io.BigQueryDestination()

        bq_dest_prefix = bigquery_destination_prefix
        # The service expects the destination URI to carry a `bq://` scheme.
        if not bq_dest_prefix.startswith("bq://"):
            bq_dest_prefix = f"bq://{bq_dest_prefix}"

        output_config.bigquery_destination.output_uri = bq_dest_prefix
    else:
        output_config.predictions_format = predictions_format
        output_config.gcs_destination = gca_io.GcsDestination(
            output_uri_prefix=gcs_destination_prefix
        )

    gapic_batch_prediction_job.input_config = input_config
    gapic_batch_prediction_job.output_config = output_config

    # Optional Fields
    gapic_batch_prediction_job.encryption_spec = initializer.global_config.get_encryption_spec(
        encryption_spec_key_name=encryption_spec_key_name,
        select_version=select_version,
    )

    if model_parameters:
        gapic_batch_prediction_job.model_parameters = model_parameters

    # Custom Compute: only configured when an explicit machine type is given,
    # otherwise the service runs the job with automatic resources.
    if machine_type:
        machine_spec = gca_machine_resources.MachineSpec()
        machine_spec.machine_type = machine_type
        machine_spec.accelerator_type = accelerator_type
        machine_spec.accelerator_count = accelerator_count

        dedicated_resources = gca_machine_resources.BatchDedicatedResources()
        dedicated_resources.machine_spec = machine_spec
        dedicated_resources.starting_replica_count = starting_replica_count
        dedicated_resources.max_replica_count = max_replica_count

        gapic_batch_prediction_job.dedicated_resources = dedicated_resources
        gapic_batch_prediction_job.manual_batch_tuning_parameters = None

    # User Labels
    gapic_batch_prediction_job.labels = labels

    # Explanations
    if generate_explanation:
        gapic_batch_prediction_job.generate_explanation = generate_explanation
        if explanation_metadata or explanation_parameters:
            gapic_batch_prediction_job.explanation_spec = gca_explanation_v1beta1.ExplanationSpec(
                metadata=explanation_metadata, parameters=explanation_parameters
            )

    # TODO (b/174502913): Support private feature once released

    api_client = cls._instantiate_client(location=location, credentials=credentials)

    return cls._create(
        api_client=api_client,
        parent=initializer.global_config.common_location_path(
            project=project, location=location
        ),
        batch_prediction_job=gapic_batch_prediction_job,
        generate_explanation=generate_explanation,
        project=project or initializer.global_config.project,
        location=location or initializer.global_config.location,
        credentials=credentials or initializer.global_config.credentials,
        sync=sync,
    )
@classmethod
@base.optional_sync()
def _create(
    cls,
    api_client: job_service_client.JobServiceClient,
    parent: str,
    batch_prediction_job: Union[
        gca_bp_job_v1beta1.BatchPredictionJob, gca_bp_job_v1.BatchPredictionJob
    ],
    generate_explanation: bool,
    project: str,
    location: str,
    credentials: Optional[auth_credentials.Credentials],
    sync: bool = True,
) -> "BatchPredictionJob":
    """Create a batch prediction job against the Job service and block
    until it reaches a terminal state.

    Args:
        api_client (job_service_client.JobServiceClient):
            Required. A JobServiceClient with the correct api_endpoint
            already set based on user's preferences.
        parent (str):
            Required. Also known as common location path: the
            project/location path the job is created under.
            Example: "projects/my-prj/locations/us-central1"
        batch_prediction_job (gca_bp_job.BatchPredictionJob):
            Required. A batch prediction job proto for creating a batch
            prediction job on Vertex AI.
        generate_explanation (bool):
            Required. Whether explanations were requested; selects the
            v1beta1 client version when True.
        project (str):
            Required. Project for the created job. Overrides project set in
            aiplatform.init.
        location (str):
            Required. Location for the created job. Overrides location set in
            aiplatform.init.
        credentials (Optional[auth_credentials.Credentials]):
            Custom credentials for the created job. Overrides
            credentials set in aiplatform.init.
        sync (bool):
            Whether to run synchronously; consumed by the
            `base.optional_sync` decorator.

    Returns:
        (jobs.BatchPredictionJob):
            Instantiated representation of the created batch prediction job.
    """
    # Explanations only exist on v1beta1; otherwise the default v1 is used.
    if generate_explanation:
        api_client = api_client.select_version(compat.V1BETA1)

    _LOGGER.log_create_with_lro(cls)

    created_resource = api_client.create_batch_prediction_job(
        parent=parent, batch_prediction_job=batch_prediction_job
    )

    # Re-fetch the created resource through the public constructor so the
    # returned object carries the user's project/location/credentials.
    job = cls(
        batch_prediction_job_name=created_resource.name,
        project=project,
        location=location,
        credentials=credentials,
    )

    _LOGGER.log_create_complete(cls, job._gca_resource, "bpj")
    _LOGGER.info("View Batch Prediction Job:\n%s" % job._dashboard_uri())

    # Block until terminal state; raises RuntimeError on failure/cancel.
    job._block_until_complete()

    return job
def iter_outputs(
    self, bq_max_results: Optional[int] = 100
) -> Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]:
    """Returns an Iterable object to traverse the output files, either a
    list of GCS Blobs or a BigQuery RowIterator depending on the output
    config set when the BatchPredictionJob was created.

    Args:
        bq_max_results: Optional[int] = 100
            Limit on rows to retrieve from prediction table in BigQuery dataset.
            Only used when retrieving predictions from a bigquery_destination_prefix.
            Default is 100.

    Returns:
        Union[Iterable[storage.Blob], Iterable[bigquery.table.RowIterator]]:
            Either a list of GCS Blob objects within the prediction output
            directory or an iterable BigQuery RowIterator with predictions.

    Raises:
        RuntimeError:
            If BatchPredictionJob is in a JobState other than SUCCEEDED,
            since outputs cannot be retrieved until the Job has finished.
        NotImplementedError:
            If BatchPredictionJob succeeded and output_info does not have a
            GCS or BQ output provided.
    """
    self._assert_gca_resource_is_available()

    # Outputs only exist once the job has run to completion.
    if self.state != gca_job_state.JobState.JOB_STATE_SUCCEEDED:
        raise RuntimeError(
            "Cannot read outputs until BatchPredictionJob has succeeded, "
            f"current state: {self._gca_resource.state}"
        )

    output_info = self._gca_resource.output_info

    # GCS Destination, return Blobs
    if output_info.gcs_output_directory:
        # Build a Storage Client using the same credentials as JobServiceClient
        storage_client = storage.Client(
            project=self.project,
            credentials=self.api_client._transport._credentials,
        )

        gcs_bucket, gcs_prefix = utils.extract_bucket_and_prefix_from_gcs_path(
            output_info.gcs_output_directory
        )
        blobs = storage_client.list_blobs(gcs_bucket, prefix=gcs_prefix)

        return blobs

    # BigQuery Destination, return RowIterator
    elif output_info.bigquery_output_dataset:
        # Format of `bigquery_output_dataset` from service is `bq://projectId.bqDatasetId`
        bq_dataset = output_info.bigquery_output_dataset
        bq_table = output_info.bigquery_output_table

        if not bq_table:
            raise RuntimeError(
                "A BigQuery table with predictions was not found, this "
                f"might be due to errors. Visit {self._dashboard_uri()} for details."
            )

        if bq_dataset.startswith("bq://"):
            bq_dataset = bq_dataset[5:]

        # Split project ID and BQ dataset ID.
        # NOTE(review): the dataset's project id is discarded and `self.project`
        # is used for the lookup instead — presumably these always match; if the
        # output dataset can live in a different project this would mis-resolve
        # the table. TODO: confirm against the service contract.
        _, bq_dataset_id = bq_dataset.split(".", 1)

        # Build a BigQuery Client using the same credentials as JobServiceClient
        bq_client = bigquery.Client(
            project=self.project,
            credentials=self.api_client._transport._credentials,
        )

        row_iterator = bq_client.list_rows(
            table=f"{bq_dataset_id}.{bq_table}", max_results=bq_max_results
        )

        return row_iterator

    # Unknown Destination type
    else:
        # Fixed: the two literals previously concatenated without a space,
        # producing "...here are detailson your prediction output...".
        raise NotImplementedError(
            "Unsupported batch prediction output location, here are details "
            f"on your prediction output:\n{output_info}"
        )
class _RunnableJob(_Job):
    """ABC to interface job as a runnable training class."""

    def __init__(
        self,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ):
        """Initializes the job with project, location, and credentials.

        Args:
            project (str): Project of the resource noun.
            location (str): The location of the resource noun.
            credentials (google.auth.credentials.Credentials): Optional custom
                credentials to use when interacting with the resource noun.
        """
        base.VertexAiResourceNounWithFutureManager.__init__(
            self, project=project, location=location, credentials=credentials
        )

        self._parent = aiplatform.initializer.global_config.common_location_path(
            project=project, location=location
        )

    @abc.abstractmethod
    def run(self) -> None:
        pass

    @classmethod
    def get(
        cls,
        resource_name: str,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ) -> "_RunnableJob":
        """Retrieves an existing Vertex AI Job by resource name.

        Args:
            resource_name (str):
                Required. A fully-qualified resource name or ID.
            project (str):
                Optional project to retrieve dataset from. If not set, project
                set in aiplatform.init will be used.
            location (str):
                Optional location to retrieve dataset from. If not set, location
                set in aiplatform.init will be used.
            credentials (auth_credentials.Credentials):
                Custom credentials to use to upload this model. Overrides
                credentials set in aiplatform.init.

        Returns:
            A Vertex AI Job.
        """
        job = cls._empty_constructor(
            project=project,
            location=location,
            credentials=credentials,
            resource_name=resource_name,
        )

        job._gca_resource = job._get_gca_resource(resource_name=resource_name)

        return job

    def wait_for_resource_creation(self) -> None:
        """Waits until resource has been created."""
        self._wait_for_resource_creation()
class DataLabelingJob(_Job):
    """Vertex AI Data Labeling Job.

    Configures the generic ``_Job`` behavior (get/list/cancel/delete and the
    console dashboard URL segment) for the ``dataLabelingJobs`` resource; it
    adds no behavior of its own.
    """

    # _Job plumbing: service resource noun and JobServiceClient method names.
    _resource_noun = "dataLabelingJobs"
    _getter_method = "get_data_labeling_job"
    _list_method = "list_data_labeling_jobs"
    _cancel_method = "cancel_data_labeling_job"
    _delete_method = "delete_data_labeling_job"
    # Console URL path segment used for the job dashboard link.
    _job_type = "labeling-tasks"
class CustomJob(_RunnableJob):
"""Vertex AI Custom Job."""
_resource_noun = "customJobs"
_getter_method = "get_custom_job"
_list_method = "list_custom_jobs"
_cancel_method = "cancel_custom_job"
_delete_method = "delete_custom_job"
_job_type = "training"
def __init__(
self,
display_name: str,
worker_pool_specs: Union[List[Dict], List[aiplatform.gapic.WorkerPoolSpec]],
base_output_dir: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
labels: Optional[Dict[str, str]] = None,
encryption_spec_key_name: Optional[str] = None,
staging_bucket: Optional[str] = None,
):
"""Cosntruct a Custom Job with Worker Pool Specs.
```
Example usage:
worker_pool_specs = [
{
"machine_spec": {
"machine_type": "n1-standard-4",
"accelerator_type": "NVIDIA_TESLA_K80",
"accelerator_count": 1,
},
"replica_count": 1,
"container_spec": {
"image_uri": container_image_uri,
"command": [],
"args": [],
},
}
]
my_job = aiplatform.CustomJob(
display_name='my_job',
worker_pool_specs=worker_pool_specs,
labels={'my_key': 'my_value'},
)
my_job.run()
```
For more information on configuring worker pool specs please visit:
https://cloud.google.com/ai-platform-unified/docs/training/create-custom-job
Args:
display_name (str):
Required. The user-defined name of the HyperparameterTuningJob.
The name can be up to 128 characters long and can be consist
of any UTF-8 characters.
worker_pool_specs (Union[List[Dict], List[aiplatform.gapic.WorkerPoolSpec]]):
Required. The spec of the worker pools including machine type and Docker image.
Can provided as a list of dictionaries or list of WorkerPoolSpec proto messages.
base_output_dir (str):
Optional. GCS output directory of job. If not provided a
timestamped directory in the staging directory will be used.
project (str):
Optional.Project to run the custom job in. Overrides project set in aiplatform.init.
location (str):
Optional.Location to run the custom job in. Overrides location set in aiplatform.init.
credentials (auth_credentials.Credentials):