Fix deferrable mode for BeamRunJavaPipelineOperator (#39371)
MaksYermak committed May 14, 2024
1 parent 1e4663f commit 1489cf7
Showing 4 changed files with 38 additions and 28 deletions.
21 changes: 3 additions & 18 deletions airflow/providers/apache/beam/operators/beam.py
@@ -546,7 +546,7 @@ def execute(self, context: Context):
if not self.beam_hook:
raise AirflowException("Beam hook is not defined.")
if self.deferrable:
asyncio.run(self.execute_async(context))
self.execute_async(context)
else:
return self.execute_sync(context)
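In deferrable mode, `execute_async()` is now an ordinary synchronous method: it no longer awaits anything itself but only builds the trigger and calls `self.defer()`, so the `asyncio.run()` wrapper is unnecessary. A minimal sketch of the deferral hand-off this relies on, using a stock trigger rather than the Beam one:

from datetime import timedelta

from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import TimeDeltaTrigger


class ExampleDeferrableOperator(BaseOperator):
    """Toy operator showing the defer -> trigger -> execute_complete flow."""

    def execute(self, context):
        # self.defer() raises TaskDeferred: the worker slot is released and the
        # triggerer process runs the trigger's async run() until it fires an event.
        self.defer(
            trigger=TimeDeltaTrigger(timedelta(seconds=30)),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event=None):
        # Invoked on a worker again once the trigger yields a TriggerEvent.
        return event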

@@ -605,23 +605,7 @@ def execute_sync(self, context: Context):
process_line_callback=self.process_line_callback,
)

async def execute_async(self, context: Context):
# Creating a new event loop to manage I/O operations asynchronously
loop = asyncio.get_event_loop()
if self.jar.lower().startswith("gs://"):
gcs_hook = GCSHook(self.gcp_conn_id)
# Running synchronous `enter_context()` method in a separate
# thread using the default executor `None`. The `run_in_executor()` function returns the
# file object, which is created using gcs function `provide_file()`, asynchronously.
# This means we can perform asynchronous operations with this file.
create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar)
tmp_gcs_file: IO[str] = await loop.run_in_executor(
None,
contextlib.ExitStack().enter_context, # type: ignore[arg-type]
create_tmp_file_call,
)
self.jar = tmp_gcs_file.name

def execute_async(self, context: Context):
if self.is_dataflow and self.dataflow_hook:
DataflowJobLink.persist(
self,
@@ -657,6 +641,7 @@ async def execute_async(self, context: Context):
job_class=self.job_class,
runner=self.runner,
check_if_running=self.dataflow_config.check_if_running == CheckJobRunning.WaitForRun,
gcp_conn_id=self.gcp_conn_id,
),
method_name="execute_complete",
)
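The `gcp_conn_id=self.gcp_conn_id` argument added above matters because the jar staging now happens inside the trigger, which builds its own GCSHook. Triggers are re-created in the separate triggerer process from whatever `serialize()` returns, so every setting the trigger needs must round-trip through its constructor kwargs. A minimal sketch of that contract, with illustrative names rather than the real Beam trigger:

from __future__ import annotations

from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent


class ExampleJarTrigger(BaseTrigger):
    """Toy trigger showing why constructor kwargs must round-trip via serialize()."""

    def __init__(self, jar: str, gcp_conn_id: str = "google_cloud_default"):
        super().__init__()
        self.jar = jar
        self.gcp_conn_id = gcp_conn_id

    def serialize(self) -> tuple[str, dict[str, Any]]:
        # The triggerer re-instantiates the trigger from exactly this payload,
        # so a forgotten kwarg (e.g. gcp_conn_id) would silently fall back to
        # its default in the deferred process.
        return (
            "example_module.ExampleJarTrigger",
            {"jar": self.jar, "gcp_conn_id": self.gcp_conn_id},
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        yield TriggerEvent({"status": "success", "jar": self.jar})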
22 changes: 20 additions & 2 deletions airflow/providers/apache/beam/triggers/beam.py
@@ -17,14 +17,16 @@
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator, Sequence
import contextlib
from typing import IO, Any, AsyncIterator, Sequence

from deprecated import deprecated
from google.cloud.dataflow_v1beta3 import ListJobsRequest

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook
from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


@@ -166,7 +168,7 @@ def __init__(
project_id: str | None = None,
location: str | None = None,
job_name: str | None = None,
gcp_conn_id: str | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
poll_sleep: int = 10,
cancel_timeout: int | None = None,
@@ -233,6 +235,22 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
if is_running:
await asyncio.sleep(self.poll_sleep)
try:
# Get the currently running event loop so blocking I/O can be off-loaded to a thread
loop = asyncio.get_running_loop()
if self.jar.lower().startswith("gs://"):
gcs_hook = GCSHook(self.gcp_conn_id)
# Run the synchronous `enter_context()` call in a separate thread via the
# default executor (`None`). `run_in_executor()` returns the file object
# created by the GCS hook's `provide_file()` context manager, so the jar
# is staged locally without blocking the event loop.
create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar)
tmp_gcs_file: IO[str] = await loop.run_in_executor(
None,
contextlib.ExitStack().enter_context, # type: ignore[arg-type]
create_tmp_file_call,
)
self.jar = tmp_gcs_file.name

return_code = await hook.start_java_pipeline_async(
variables=self.variables, jar=self.jar, job_class=self.job_class
)
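The staging block added above leans on a small asyncio pattern: a synchronous context manager (`GCSHook.provide_file()`) is entered from async code by pushing its blocking `__enter__` onto the default thread-pool executor through `contextlib.ExitStack().enter_context`. A self-contained sketch of the same idea, with `provide_local_copy` as a hypothetical stand-in for the GCS hook:

import asyncio
import contextlib
import tempfile


@contextlib.contextmanager
def provide_local_copy(object_url: str):
    # Hypothetical stand-in for GCSHook.provide_file(): a *synchronous*
    # context manager that yields a local temporary copy of the object.
    with tempfile.NamedTemporaryFile(suffix=".jar") as tmp:
        yield tmp


async def stage_jar(object_url: str) -> str:
    loop = asyncio.get_running_loop()
    # enter_context() performs the blocking __enter__ in the default thread-pool
    # executor, so the event loop keeps servicing other triggers meanwhile.
    # Note: the inline ExitStack is never closed, so the context manager's
    # cleanup is deliberately not run here.
    tmp_file = await loop.run_in_executor(
        None,
        contextlib.ExitStack().enter_context,
        provide_local_copy(object_url),
    )
    return tmp_file.name


print(asyncio.run(stage_jar("gs://my-bucket/example/test.jar")))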
10 changes: 2 additions & 8 deletions tests/providers/apache/beam/operators/test_beam.py
@@ -1013,24 +1013,20 @@ def test_async_execute_should_execute_successfully(self, gcs_hook, beam_hook_moc
), "Trigger is not a BeamPJavaPipelineTrigger"

@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_async_execute_direct_runner(self, gcs_hook, beam_hook_mock):
def test_async_execute_direct_runner(self, beam_hook_mock):
"""
Test BeamHook is created and the right args are passed to
start_java_pipeline when executing direct runner.
"""
gcs_provide_file = gcs_hook.return_value.provide_file
op = BeamRunJavaPipelineOperator(**self.default_op_kwargs)
with pytest.raises(TaskDeferred):
op.execute(context=mock.MagicMock())
beam_hook_mock.assert_called_once_with(runner=DEFAULT_RUNNER)
gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)

@mock.patch(BEAM_OPERATOR_PATH.format("DataflowJobLink.persist"))
@mock.patch(BEAM_OPERATOR_PATH.format("BeamHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("DataflowHook"))
@mock.patch(BEAM_OPERATOR_PATH.format("GCSHook"))
def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
def test_exec_dataflow_runner(self, dataflow_hook_mock, beam_hook_mock, persist_link_mock):
"""
Test DataflowHook is created and the right args are passed to
start_java_pipeline when executing Dataflow runner.
@@ -1039,7 +1035,6 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
op = BeamRunJavaPipelineOperator(
runner="DataflowRunner", dataflow_config=dataflow_config, **self.default_op_kwargs
)
gcs_provide_file = gcs_hook.return_value.provide_file
magic_mock = mock.MagicMock()
with pytest.raises(TaskDeferred):
op.execute(context=magic_mock)
@@ -1062,7 +1057,6 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
"region": "us-central1",
"impersonate_service_account": TEST_IMPERSONATION_ACCOUNT,
}
gcs_provide_file.assert_called_once_with(object_url=JAR_FILE)
persist_link_mock.assert_called_once_with(
op,
magic_mock,
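With the GCS assertions now living in the trigger tests, the operator tests only need to check that execution defers with the expected trigger. A hedged sketch of one way to express that, using a hypothetical helper rather than code from this commit:

import pytest
from airflow.exceptions import TaskDeferred


def assert_defers_with(operator, trigger_cls, context=None):
    """Run the operator and return the trigger it deferred with."""
    with pytest.raises(TaskDeferred) as exc_info:
        operator.execute(context=context or {})
    # TaskDeferred carries the trigger passed to self.defer(), so its type
    # (and, if needed, its serialized kwargs) can be asserted on directly.
    assert isinstance(exc_info.value.trigger, trigger_cls)
    return exc_info.value.trigger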
13 changes: 13 additions & 0 deletions tests/providers/apache/beam/triggers/test_beam.py
@@ -43,6 +43,7 @@
TEST_PY_PACKAGES = False
TEST_RUNNER = "DirectRunner"
TEST_JAR_FILE = "example.jar"
TEST_GCS_JAR_FILE = "gs://my-bucket/example/test.jar"
TEST_JOB_CLASS = "TestClass"
TEST_CHECK_IF_RUNNING = False
TEST_JOB_NAME = "test_job_name"
@@ -215,3 +216,15 @@ async def test_beam_trigger_exception_list_jobs_should_execute_successfully(
generator = java_trigger.run()
actual = await generator.asend(None)
assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.beam.triggers.beam.GCSHook")
async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, gcs_hook, java_trigger):
"""
Test that BeamJavaPipelineTrigger stages the jar file from GCS via provide_file correctly.
"""
gcs_provide_file = gcs_hook.return_value.provide_file
java_trigger.jar = TEST_GCS_JAR_FILE
generator = java_trigger.run()
await generator.asend(None)
gcs_provide_file.assert_called_once_with(object_url=TEST_GCS_JAR_FILE)
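The new test drives the trigger a single step with `asend(None)`, the async-generator equivalent of 'run until the next yield'. A toy illustration of that mechanic, independent of the Beam classes:

import asyncio


async def toy_run():
    # Stand-in for a trigger's run(): do some work, then yield an event.
    yield {"status": "success"}


async def main():
    gen = toy_run()
    # asend(None) resumes the generator until its first yield (or completion),
    # which is how the tests above step through BaseTrigger.run().
    event = await gen.asend(None)
    assert event == {"status": "success"}


asyncio.run(main())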
