Merge pull request #6 from aelzeiny/add-executor-queues-to-fargate
Add Executor Queues to Fargate, and improve APIs
aelzeiny committed Jan 28, 2021
2 parents f530c46 + bae3ba2 commit 61daa98
Showing 7 changed files with 111 additions and 81 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -4,7 +4,7 @@ python:
- "3.7"
- "3.8"
install:
- pip install apache-airflow boto3 pylint isort
- pip install apache-airflow boto3 pylint isort marshmallow
env:
- AIRFLOW__BATCH__REGION=us-west-1 AIRFLOW__BATCH__JOB_NAME=some-job-name AIRFLOW__BATCH__JOB_QUEUE=some-job-queue AIRFLOW__BATCH__JOB_DEFINITION=some-job-def AIRFLOW__ECS_FARGATE__REGION=us-west-1 AIRFLOW__ECS_FARGATE__CLUSTER=some-cluster AIRFLOW__ECS_FARGATE__CONTAINER_NAME=some-container-name AIRFLOW__ECS_FARGATE__TASK_DEFINITION=some-task-def AIRFLOW__ECS_FARGATE__LAUNCH_TYPE=FARGATE AIRFLOW__ECS_FARGATE__PLATFORM_VERSION=LATEST AIRFLOW__ECS_FARGATE__ASSIGN_PUBLIC_IP=DISABLED AIRFLOW__ECS_FARGATE__SECURITY_GROUPS=SG1,SG2 AIRFLOW__ECS_FARGATE__SUBNETS=SUB1,SUB2
script:
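
For context: Airflow resolves environment variables of the form AIRFLOW__{SECTION}__{KEY} into configuration options at read time, which is how the CI line above configures both executors without touching airflow.cfg. A minimal sketch of that mapping (the value is illustrative):

```python
# Airflow maps AIRFLOW__{SECTION}__{KEY} environment variables onto
# [section] key config options; the cluster name below is illustrative only.
import os

os.environ['AIRFLOW__ECS_FARGATE__CLUSTER'] = 'some-cluster'

from airflow.configuration import conf

print(conf.get('ecs_fargate', 'cluster'))  # -> 'some-cluster'
```
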
39 changes: 25 additions & 14 deletions airflow_aws_executors/batch_executor.py
@@ -9,7 +9,7 @@
from airflow.executors.base_executor import BaseExecutor
from airflow.utils.module_loading import import_string
from airflow.utils.state import State
from marshmallow import Schema, fields, post_load
from marshmallow import EXCLUDE, Schema, ValidationError, fields, post_load

CommandType = List[str]
TaskInstanceKeyType = Tuple[Any]
@@ -105,16 +105,17 @@ def _describe_tasks(self, job_ids) -> List[BatchJob]:
for i in range((len(job_ids) // max_batch_size) + 1):
batched_job_ids = job_ids[i * max_batch_size: (i + 1) * max_batch_size]
boto_describe_tasks = self.batch.describe_jobs(jobs=batched_job_ids)
describe_tasks_response = BatchDescribeJobsResponseSchema().load(boto_describe_tasks)
if describe_tasks_response.errors:
try:
describe_tasks_response = BatchDescribeJobsResponseSchema().load(boto_describe_tasks)
except ValidationError as err:
self.log.error('Batch DescribeJobs API Response: %s', boto_describe_tasks)
raise BatchError(
'DescribeJobs API call does not match expected JSON shape. '
'Are you sure that the correct version of Boto3 is installed? {}'.format(
describe_tasks_response.errors
err
)
)
all_jobs.extend(describe_tasks_response.data['jobs'])
all_jobs.extend(describe_tasks_response['jobs'])
return all_jobs

def execute_async(self, key: TaskInstanceKeyType, command: CommandType, queue=None, executor_config=None):
@@ -135,16 +136,17 @@ def _submit_job(self, cmd: CommandType, exec_config: ExecutorConfigType) -> str:
submit_job_api['containerOverrides'].update(exec_config)
submit_job_api['containerOverrides']['command'] = cmd
boto_run_task = self.batch.submit_job(**submit_job_api)
submit_job_response = BatchSubmitJobResponseSchema().load(boto_run_task)
if submit_job_response.errors:
self.log.error('Batch SubmitJob Response: %s', submit_job_response)
try:
submit_job_response = BatchSubmitJobResponseSchema().load(boto_run_task)
except ValidationError as err:
self.log.error('Batch SubmitJob Response: %s', err)
raise BatchError(
'SubmitJob API call does not match expected JSON shape. '
'Are you sure that the correct version of Boto3 is installed? {}'.format(
submit_job_response.errors
err
)
)
return submit_job_response.data['job_id']
return submit_job_response['job_id']

def end(self, heartbeat_interval=10):
"""
@@ -213,29 +215,38 @@ def __len__(self):
class BatchSubmitJobResponseSchema(Schema):
"""API Response for SubmitJob"""
# The unique identifier for the job.
job_id = fields.String(load_from='jobId', required=True)
job_id = fields.String(data_key='jobId', required=True)

class Meta:
unknown = EXCLUDE


class BatchJobDetailSchema(Schema):
"""API Response for Describe Jobs"""
# The unique identifier for the job.
job_id = fields.String(load_from='jobId', required=True)
job_id = fields.String(data_key='jobId', required=True)
# The current status for the job: 'SUBMITTED', 'PENDING', 'RUNNABLE', 'STARTING', 'RUNNING', 'SUCCEEDED', 'FAILED'
status = fields.String(required=True)
# A short, human-readable string to provide additional details about the current status of the job.
status_reason = fields.String(load_from='statusReason')
status_reason = fields.String(data_key='statusReason')

@post_load
def make_job(self, data, **kwargs):
"""Overwrites marshmallow data property to return an instance of BatchJob instead of a dictionary"""
"""Overwrites marshmallow load() to return an instance of BatchJob instead of a dictionary"""
return BatchJob(**data)

class Meta:
unknown = EXCLUDE


class BatchDescribeJobsResponseSchema(Schema):
"""API Response for Describe Jobs"""
# The list of jobs
jobs = fields.List(fields.Nested(BatchJobDetailSchema), required=True)

class Meta:
unknown = EXCLUDE


class BatchError(Exception):
"""Thrown when something unexpected has occurred within the AWS Batch ecosystem"""
3 changes: 2 additions & 1 deletion airflow_aws_executors/conf.py
@@ -25,7 +25,8 @@
from airflow.configuration import conf


def has_option(section, config_name):
def has_option(section, config_name) -> bool:
"""Returns True if configuration has a section and an option"""
if conf.has_option(section, config_name):
config_val = conf.get(section, config_name)
return config_val is not None and config_val != ''
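
The hunk above shows only the truncated head of has_option(). Restated as a self-contained sketch, assuming the elided tail simply falls through to False:

```python
# Hedged reconstruction of the full helper; the final `return False` is an
# assumption, since the diff cuts off before the function's last branch.
from airflow.configuration import conf

def has_option(section, config_name) -> bool:
    """Returns True if configuration has a section and a non-empty option"""
    if conf.has_option(section, config_name):
        config_val = conf.get(section, config_name)
        return config_val is not None and config_val != ''
    return False

# e.g. only apply an optional override when it is actually configured
if has_option('ecs_fargate', 'platform_version'):
    platform_version = conf.get('ecs_fargate', 'platform_version')
```
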
102 changes: 66 additions & 36 deletions airflow_aws_executors/ecs_fargate_executor.py
@@ -10,14 +10,14 @@
from airflow.executors.base_executor import BaseExecutor
from airflow.utils.module_loading import import_string
from airflow.utils.state import State
from marshmallow import Schema, fields, post_load
from marshmallow import EXCLUDE, Schema, ValidationError, fields, post_load

CommandType = List[str]
TaskInstanceKeyType = Tuple[Any]
ExecutorConfigFunctionType = Callable[[CommandType], dict]
EcsFargateQueuedTask = namedtuple('EcsFargateQueuedTask', ('key', 'command', 'executor_config'))
EcsFargateQueuedTask = namedtuple('EcsFargateQueuedTask', ('key', 'command', 'queue', 'executor_config'))
ExecutorConfigType = Dict[str, Any]
EcsFargateTaskInfo = namedtuple('EcsFargateTaskInfo', ('cmd', 'config'))
EcsFargateTaskInfo = namedtuple('EcsFargateTaskInfo', ('cmd', 'queue', 'config'))


class EcsFargateTask:
@@ -147,17 +147,18 @@ def __describe_tasks(self, task_arns):
for i in range((len(task_arns) // self.DESCRIBE_TASKS_BATCH_SIZE) + 1):
batched_task_arns = task_arns[i * self.DESCRIBE_TASKS_BATCH_SIZE: (i + 1) * self.DESCRIBE_TASKS_BATCH_SIZE]
boto_describe_tasks = self.ecs.describe_tasks(tasks=batched_task_arns, cluster=self.cluster)
describe_tasks_response = BotoDescribeTasksSchema().load(boto_describe_tasks)
if describe_tasks_response.errors:
try:
describe_tasks_response = BotoDescribeTasksSchema().load(boto_describe_tasks)
except ValidationError as err:
self.log.error('ECS DescribeTask Response: %s', boto_describe_tasks)
raise EcsFargateError(
'DescribeTasks API call does not match expected JSON shape. '
'Are you sure that the correct version of Boto3 is installed? {}'.format(
describe_tasks_response.errors
err
)
)
all_task_descriptions['tasks'].extend(describe_tasks_response.data['tasks'])
all_task_descriptions['failures'].extend(describe_tasks_response.data['failures'])
all_task_descriptions['tasks'].extend(describe_tasks_response['tasks'])
all_task_descriptions['failures'].extend(describe_tasks_response['failures'])
return all_task_descriptions

def __handle_failed_task(self, task_arn: str, reason: str):
@@ -166,14 +167,14 @@ def __handle_failed_task(self, task_arn: str, reason: str):
ECS/Fargate Cloud. If an API failure occurs the task is simply rescheduled.
"""
task_key = self.active_workers.arn_to_key[task_arn]
task_cmd, exec_info = self.active_workers.info_by_key(task_key)
task_cmd, queue, exec_info = self.active_workers.info_by_key(task_key)
failure_count = self.active_workers.failure_count_by_key(task_key)
if failure_count < self.__class__.MAX_FAILURE_CHECKS:
self.log.warning('Task %s has failed due to %s. '
'Failure %s out of %s occurred on %s. Rescheduling.',
task_key, reason, failure_count, self.__class__.MAX_FAILURE_CHECKS, task_arn)
self.active_workers.increment_failure_count(task_key)
self.pending_tasks.appendleft(EcsFargateQueuedTask(task_key, task_cmd, exec_info))
self.pending_tasks.appendleft(EcsFargateQueuedTask(task_key, task_cmd, queue, exec_info))
else:
self.log.error('Task %s has failed a maximum of %s times. Marking as failed', task_key,
failure_count)
@@ -192,8 +193,8 @@ def attempt_task_runs(self):
failure_reasons = defaultdict(int)
for _ in range(queue_len):
ecs_task = self.pending_tasks.popleft()
task_key, cmd, exec_config = ecs_task
run_task_response = self.__run_task(cmd, exec_config)
task_key, cmd, queue, exec_config = ecs_task
run_task_response = self._run_task(task_key, cmd, queue, exec_config)
if run_task_response['failures']:
for f in run_task_response['failures']:
failure_reasons[f['reason']] += 1
@@ -203,39 +204,53 @@ def attempt_task_runs(self):
raise EcsFargateError('No failures and no tasks provided in response. This should never happen.')
else:
task = run_task_response['tasks'][0]
self.active_workers.add_task(task, task_key, cmd, exec_config)
self.active_workers.add_task(task, task_key, queue, cmd, exec_config)
if failure_reasons:
self.log.debug('Pending tasks failed to launch for the following reasons: %s. Will retry later.',
dict(failure_reasons))

def __run_task(self, cmd: CommandType, exec_config: ExecutorConfigType):
def _run_task(self, task_id: TaskInstanceKeyType, cmd: CommandType, queue: str, exec_config: ExecutorConfigType):
"""
This function is the actual attempt to run a queued-up airflow task. Not to be confused with
execute_async() which inserts tasks into the queue.
The command and executor config will be placed in the container-override section of the JSON request, before
calling Boto3's "run_task" function.
"""
run_task_api = deepcopy(self.run_task_kwargs)
container_override = self.get_container(run_task_api['overrides']['containerOverrides'])
container_override['command'] = cmd
container_override.update(exec_config)
run_task_api = self._run_task_kwargs(task_id, cmd, queue, exec_config)
boto_run_task = self.ecs.run_task(**run_task_api)
run_task_response = BotoRunTaskSchema().load(boto_run_task)
if run_task_response.errors:
self.log.error('ECS RunTask Response: %s', run_task_response)
try:
run_task_response = BotoRunTaskSchema().load(boto_run_task)
except ValidationError as err:
self.log.error('ECS RunTask Response: %s', err)
raise EcsFargateError(
'RunTask API call does not match expected JSON shape. '
'Are you sure that the correct version of Boto3 is installed? {}'.format(
run_task_response.errors
err
)
)
return run_task_response.data
return run_task_response

def _run_task_kwargs(self, task_id: TaskInstanceKeyType, cmd: CommandType,
queue: str, exec_config: ExecutorConfigType) -> dict:
"""
This modifies the standard kwargs to be specific to this task by overriding the airflow command and updating
the container overrides.
One last chance to modify Boto3's "run_task" kwarg params before it gets passed into the Boto3 client.
"""
run_task_api = deepcopy(self.run_task_kwargs)
container_override = self.get_container(run_task_api['overrides']['containerOverrides'])
container_override['command'] = cmd
container_override.update(exec_config)
return run_task_api
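
Splitting _run_task_kwargs() out of _run_task() is what makes the new queue parameter useful: a subclass can override just the kwargs step to route Airflow queues to different capacity providers or task definitions. A hypothetical sketch (the AwsEcsFargateExecutor class name and the 'spot' queue mapping are assumptions, not part of this diff):

```python
# Hypothetical subclass illustrating the hook this refactor opens up.
# AwsEcsFargateExecutor and the 'spot' -> FARGATE_SPOT mapping are assumed.
from airflow_aws_executors.ecs_fargate_executor import AwsEcsFargateExecutor

class QueueAwareFargateExecutor(AwsEcsFargateExecutor):
    def _run_task_kwargs(self, task_id, cmd, queue, exec_config) -> dict:
        run_task_api = super()._run_task_kwargs(task_id, cmd, queue, exec_config)
        if queue == 'spot':
            # launchType and capacityProviderStrategy are mutually exclusive
            run_task_api.pop('launchType', None)
            run_task_api['capacityProviderStrategy'] = [
                {'capacityProvider': 'FARGATE_SPOT', 'weight': 1},
            ]
        return run_task_api
```
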

def execute_async(self, key: TaskInstanceKeyType, command: CommandType, queue=None, executor_config=None):
"""
Save the task to be executed in the next sync using Boto3's RunTask API
Save the task to be executed in the next sync by inserting the commands into a queue.
"""
if executor_config and ('name' in executor_config or 'command' in executor_config):
raise ValueError('Executor Config should never override "name" or "command"')
self.pending_tasks.append(EcsFargateQueuedTask(key, command, executor_config or {}))
self.pending_tasks.append(EcsFargateQueuedTask(key, command, queue, executor_config or {}))
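
On the caller's side, the reserved-key check above means per-task overrides flow through executor_config into the containerOverrides section of the RunTask request. A hypothetical DAG snippet (operator, ids, and values are illustrative; 'cpu' and 'memory' are standard ECS container-override keys):

```python
# Hypothetical task, assumed to live inside a `with DAG(...)` block.
from airflow.operators.bash_operator import BashOperator

heavy_task = BashOperator(
    task_id='heavy_lift',
    bash_command='python -m my_jobs.heavy_lift',  # illustrative command
    executor_config={
        'cpu': 2048,     # ECS containerOverrides key
        'memory': 4096,  # ECS containerOverrides key
        # 'name' or 'command' here would raise ValueError (reserved above)
    },
    queue='spot',  # carried through EcsFargateQueuedTask to the run-task call
)
```
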

def end(self, heartbeat_interval=10):
"""
@@ -298,14 +313,14 @@ def __init__(self):
self.key_to_failure_counts: Dict[TaskInstanceKeyType, int] = defaultdict(int)
self.key_to_task_info: Dict[TaskInstanceKeyType, EcsFargateTaskInfo] = {}

def add_task(self, task: EcsFargateTask, airflow_task_key: TaskInstanceKeyType, airflow_cmd: CommandType,
exec_config: ExecutorConfigType):
def add_task(self, task: EcsFargateTask, airflow_task_key: TaskInstanceKeyType, queue: str,
airflow_cmd: CommandType, exec_config: ExecutorConfigType):
"""Adds a task to the collection"""
arn = task.task_arn
self.tasks[arn] = task
self.key_to_arn[airflow_task_key] = arn
self.arn_to_key[arn] = airflow_task_key
self.key_to_task_info[airflow_task_key] = EcsFargateTaskInfo(airflow_cmd, exec_config)
self.key_to_task_info[airflow_task_key] = EcsFargateTaskInfo(airflow_cmd, queue, exec_config)

def update_task(self, task: EcsFargateTask):
"""Updates the state of the given task based on task ARN"""
@@ -366,28 +381,34 @@ class BotoContainerSchema(Schema):
Botocore Serialization Object for ECS 'Container' shape.
Note that there are many more parameters, but the executor only needs the members listed below.
"""
exit_code = fields.Integer(load_from='exitCode')
last_status = fields.String(load_from='lastStatus')
exit_code = fields.Integer(data_key='exitCode')
last_status = fields.String(data_key='lastStatus')
name = fields.String(required=True)

class Meta:
unknown = EXCLUDE


class BotoTaskSchema(Schema):
"""
Botocore Serialization Object for ECS 'Task' shape.
Note that there are many more parameters, but the executor only needs the members listed below.
"""
task_arn = fields.String(load_from='taskArn', required=True)
last_status = fields.String(load_from='lastStatus', required=True)
desired_status = fields.String(load_from='desiredStatus', required=True)
task_arn = fields.String(data_key='taskArn', required=True)
last_status = fields.String(data_key='lastStatus', required=True)
desired_status = fields.String(data_key='desiredStatus', required=True)
containers = fields.List(fields.Nested(BotoContainerSchema), required=True)
started_at = fields.Field(load_from='startedAt')
stopped_reason = fields.String(load_from='stoppedReason')
started_at = fields.Field(data_key='startedAt')
stopped_reason = fields.String(data_key='stoppedReason')

@post_load
def make_task(self, data, **kwargs):
"""Overwrites marshmallow .data property to return an instance of EcsFargateTask instead of a dictionary"""
"""Overwrites marshmallow load() to return an instance of EcsFargateTask instead of a dictionary"""
return EcsFargateTask(**data)

class Meta:
unknown = EXCLUDE


class BotoFailureSchema(Schema):
"""
@@ -396,6 +417,9 @@ class BotoFailureSchema(Schema):
arn = fields.String()
reason = fields.String()

class Meta:
unknown = EXCLUDE


class BotoRunTaskSchema(Schema):
"""
@@ -404,6 +428,9 @@ class BotoRunTaskSchema(Schema):
tasks = fields.List(fields.Nested(BotoTaskSchema), required=True)
failures = fields.List(fields.Nested(BotoFailureSchema), required=True)

class Meta:
unknown = EXCLUDE


class BotoDescribeTasksSchema(Schema):
"""
@@ -412,6 +439,9 @@ class BotoDescribeTasksSchema(Schema):
tasks = fields.List(fields.Nested(BotoTaskSchema), required=True)
failures = fields.List(fields.Nested(BotoFailureSchema), required=True)

class Meta:
unknown = EXCLUDE


class EcsFargateError(Exception):
"""Thrown when something unexpected has occurred within the AWS ECS/Fargate ecosystem"""
4 changes: 2 additions & 2 deletions setup.py
@@ -12,7 +12,7 @@

setup(
name="airflow-aws-executors",
version="1.0.0",
version="1.1.0",
description=description,
long_description=long_description,
long_description_content_type="text/markdown",
@@ -29,5 +29,5 @@
],
packages=["airflow_aws_executors"],
include_package_data=True,
install_requires=["boto3", "apache-airflow>=1.10.5"]
install_requires=["boto3", "apache-airflow>=1.10.5", "marshmallow>=3"]
)
