Skip to content

Commit

Permalink
Optionally detect parameter collisions - fix #2566
Browse files Browse the repository at this point in the history
  • Loading branch information
jeverling committed May 30, 2022
1 parent 62d1672 commit 401fb38
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 6 deletions.
10 changes: 10 additions & 0 deletions doc/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,16 @@ check_complete_on_run
missing.
Defaults to false.

prevent_parameter_collision
In complex pipelines especially when tasks are inherited, it can happen that
different tasks define parameters with the same name. Luigi would normally use
the same value for both parameter instances, which might not be desired.
When set to ``true``, luigi will check for parameter collisions and refuse to
run if a parameter is defined multiple times. Optionally, an allow-list of
parameters called ``collisions_to_ignore`` can be passed to ``inherits/requires``,
to ignore when checking for duplicate parameters.
Defaults to false.


[elasticsearch]
---------------
Expand Down
95 changes: 89 additions & 6 deletions luigi/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,10 @@ class TaskB(luigi.Task):

import datetime
import logging
from configparser import NoOptionError, NoSectionError

from luigi import task
from luigi import parameter
from luigi import parameter, task
from luigi.configuration import get_config


logger = logging.getLogger('luigi-interface')
Expand Down Expand Up @@ -277,18 +278,36 @@ def requires(self):
def run(self):
print self.n # this will be defined
# ...
inherits/requires decorator optionally takes an argument called
`collisions_to_ignore` with an iterable of parameters that are
allowed to overwrite parameters in upstream tasks.
In complex pipelines, it can happen that different tasks define parameters
with the same name.
If `prevent-parameter-collision` in the `[worker]` section of the config
is true, luigi will raise an exception in case of parameter conflicts -
unless the parameter is explicitly allowed in `collisions_to_ignore`.
"""

def __init__(self, *tasks_to_inherit, **kw_tasks_to_inherit):
def __init__(
self,
*tasks_to_inherit,
collisions_to_ignore=(),
**kw_tasks_to_inherit,
):
super(inherits, self).__init__()
if not tasks_to_inherit and not kw_tasks_to_inherit:
raise TypeError("tasks_to_inherit or kw_tasks_to_inherit must contain at least one task")
if tasks_to_inherit and kw_tasks_to_inherit:
raise TypeError("Only one of tasks_to_inherit or kw_tasks_to_inherit may be present")
self.tasks_to_inherit = tasks_to_inherit
self.kw_tasks_to_inherit = kw_tasks_to_inherit
self.collisions_to_ignore = collisions_to_ignore

def __call__(self, task_that_inherits):
# Check for parameter collisions and raise an exception if found
self._check_for_parameter_collisions(task_that_inherits)

# Get all parameter objects from each of the underlying tasks
task_iterator = self.tasks_to_inherit or self.kw_tasks_to_inherit.values()
for task_to_inherit in task_iterator:
Expand Down Expand Up @@ -323,6 +342,63 @@ def clone_parents(_self, **kwargs):

return task_that_inherits

def _check_for_parameter_collisions(self, task_that_inherits):
"""
Check that the parameters from the tasks_to_inherit don't
silently mask each other or by parameters from the inheriting
task.
An exception will be raised immediately when the first parameter
collision is encountered.
Collisions can be ignored by passing `collisions_to_ignore` with
an interable of allowed parameters to `inherits/requires`.
"""
# only check for parameter collisions when enabled in config
config = get_config()
try:
if config.getboolean("worker", "prevent_parameter_collision") is not True:
return
except (NoSectionError, NoOptionError, KeyError):
return

error_msg = (
'Parameter "{param}" in "{task}" is duplicated in "{another_task}" '
"(or an ancestor). Either rename one of the parameters or include "
'"{param}" in `collisions_to_ignore`.'
)

for task_to_inherit in self.tasks_to_inherit:
for param_name, _ in task_to_inherit.get_params():
# Check that the parameters from the inheriting task don't mask any
# parameters from the inherited tasks.
if (
hasattr(task_that_inherits, param_name)
and param_name not in self.collisions_to_ignore
):
raise ValueError(
error_msg.format(
param=param_name,
task=task_that_inherits.task_family,
another_task=task_to_inherit.task_family,
)
)
# Check that the parameters from an inherited task don't mask the
# parameters from another inherited task.
for another_task_to_inherit in self.tasks_to_inherit:
if (
hasattr(another_task_to_inherit, param_name)
and another_task_to_inherit is not task_to_inherit
and param_name not in self.collisions_to_ignore
):
raise ValueError(
error_msg.format(
param=param_name,
task=task_to_inherit.task_family,
another_task=another_task_to_inherit.task_family,
)
)


class requires:
"""
Expand All @@ -332,14 +408,21 @@ class requires:
"""

def __init__(self, *tasks_to_require, **kw_tasks_to_require):
def __init__(
self, *tasks_to_require, collisions_to_ignore=(), **kw_tasks_to_require
):
super(requires, self).__init__()

self.tasks_to_require = tasks_to_require
self.kw_tasks_to_require = kw_tasks_to_require
self.collisions_to_ignore = collisions_to_ignore

def __call__(self, task_that_requires):
task_that_requires = inherits(*self.tasks_to_require, **self.kw_tasks_to_require)(task_that_requires)
task_that_requires = inherits(
*self.tasks_to_require,
collisions_to_ignore=self.collisions_to_ignore,
**self.kw_tasks_to_require,
)(task_that_requires)

# Modify task_that_requires by adding requires method.
# If only one task is required, this single task is returned.
Expand Down Expand Up @@ -387,7 +470,7 @@ def run(_self):


def delegates(task_that_delegates):
""" Lets a task call methods on subtask(s).
"""Lets a task call methods on subtask(s).
The way this works is that the subtask is run as a part of the task, but
the task itself doesn't have to care about the requirements of the subtasks.
Expand Down
49 changes: 49 additions & 0 deletions test/parameter_collision_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import unittest

import luigi
from luigi.util import requires

from helpers import with_config


class A(luigi.Task):
num = luigi.IntParameter()


class B(luigi.Task):
num = luigi.IntParameter()


class ParameterCollisionDetectionTest(unittest.TestCase):
@with_config({"worker": {"prevent_parameter_collision": "true"}})
def test_parameter_collision_with_inherited_task(self):
with self.assertRaises(ValueError):

@requires(A)
class T(luigi.Task):
num = luigi.IntParameter()

@with_config({"worker": {"prevent_parameter_collision": "true"}})
def test_parameter_collision_in_inheriting_tasks(self):
with self.assertRaises(ValueError):

@requires(A, B)
class T(luigi.Task):
pass

def test_no_parameter_collision_when_disabled_in_config(self):
@requires(A, B)
class T(luigi.Task):
pass

@with_config({"worker": {"prevent_parameter_collision": "true"}})
def test_parameter_collision_with_inherited_task_ignored_by_allowlist(self):
@requires(A, collisions_to_ignore=["num"])
class T(luigi.Task):
num = luigi.IntParameter()

@with_config({"worker": {"prevent_parameter_collision": "true"}})
def test_parameter_collision_in_inheriting_tasks_ignored_by_allowlist(self):
@requires(A, B, collisions_to_ignore=["num"])
class T(luigi.Task):
pass

0 comments on commit 401fb38

Please sign in to comment.