Skip to content

Commit

Permalink
initial: introduce BigQueryGetQueryResultsOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
shahar1 committed Apr 29, 2024
1 parent 25f901a commit d41aa71
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
7 changes: 7 additions & 0 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
UnknownJob,
)
from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference
from google.cloud.bigquery.retry import DEFAULT_JOB_RETRY
from google.cloud.bigquery.table import EncryptionConfiguration, Row, RowIterator, Table, TableReference
from google.cloud.exceptions import NotFound
from googleapiclient.discovery import Resource, build
Expand Down Expand Up @@ -2390,6 +2391,12 @@ def var_print(var_name):

return project_id, dataset_id, table_id

def get_query_results(self, job_id, project_id, location, max_results, retry: Retry = DEFAULT_JOB_RETRY):
job = self.get_job(job_id=job_id, project_id=project_id, location=location)
if not isinstance(job, QueryJob):
raise AirflowException(f"Job {job_id} is not a query job")
return job.result(max_results=max_results, job_retry=retry)

@property
def scopes(self) -> Sequence[str]:
"""
Expand Down
73 changes: 73 additions & 0 deletions airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2945,3 +2945,76 @@ def on_kill(self) -> None:
)
else:
self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id)


class BigQueryGetJobResultsOperator(GoogleCloudBaseOperator):
"""
Fetches results from a BigQuery query job given a job id.
:param project_id: Google Cloud Project where the job ran (templated)
:param job_id: The ID of the job.
The ID must contain only letters (a-z, A-Z), numbers (0-9), underscores (_), or
dashes (-). The maximum length is 1,024 characters. (templated)
:param max_results: The maximum number of records (rows) to be fetched
from the table. (templated)
:param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud.
:param location: The location used for the operation.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: Run operator in the deferrable mode
:param poll_interval: (Deferrable mode only) polling period in seconds to check for the status of job.
Defaults to 4 seconds.
"""

template_fields: Sequence[str] = (
"project_id",
"job_id",
"max_results",
"impersonation_chain",
)

def __init__(
self,
project_id: str = PROVIDE_PROJECT_ID,
job_id: str | None = None,
max_results: int | None = None,
as_dict: bool = False,
location: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
gcp_conn_id: str = "google_cloud_default",
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: float = 4.0,
**kwargs,
):
super().__init__(**kwargs)
self.project_id = project_id
self.job_id = job_id
self.max_results = max_results
self.location = location
self.impersonation_chain = impersonation_chain
self.gcp_conn_id = gcp_conn_id
self.deferrable = deferrable
self.poll_interval = poll_interval
self.as_dict = as_dict

def execute(self, context: Context):
bq_hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
rows = bq_hook.get_query_results(
job_id=self.job_id,
project_id=self.project_id,
location=self.location,
max_results=self.max_results,
)

if self.as_dict:
table_data = [dict(row) for row in rows]
else:
table_data = [row.values() for row in rows]

return table_data

0 comments on commit d41aa71

Please sign in to comment.