Skip to content

Commit

Permalink
Fix schema validation and add custom validators (#3220)
Browse files Browse the repository at this point in the history
* Fix schema validation and add custom validators

* Fix CI

* Simplify freezing processes
  • Loading branch information
adrien-berchet committed Jan 17, 2023
1 parent cf3be33 commit ddd94d0
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 12 deletions.
73 changes: 65 additions & 8 deletions luigi/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,26 @@ def run(self):
$ luigi --module my_tasks MyTask --tags '{"role": "UNKNOWN_VALUE", "env": "staging"}'
Finally, the provided schema can be a custom validator:
.. code-block:: python
custom_validator = jsonschema.Draft4Validator(
schema={
"type": "object",
"patternProperties": {
".*": {"type": "string", "enum": ["web", "staging"]},
}
}
)
class MyTask(luigi.Task):
tags = luigi.DictParameter(schema=custom_validator)
def run(self):
logging.info("Find server with role: %s", self.tags['role'])
server = aws.ec2.find_my_resource(self.tags)
"""

def __init__(
Expand All @@ -1105,7 +1125,9 @@ def __init__(
"The 'jsonschema' package is not installed so the parameter can not be validated "
"even though a schema is given."
)
self.schema = schema
self.schema = None
else:
self.schema = schema
super().__init__(
*args,
**kwargs,
Expand All @@ -1115,10 +1137,14 @@ def normalize(self, value):
"""
Ensure that the dictionary parameter is converted to a FrozenOrderedDict so it can be hashed.
"""
frozen_value = recursively_freeze(value)
if self.schema is not None:
jsonschema.validate(instance=recursively_unfreeze(frozen_value), schema=self.schema)
return frozen_value
unfrozen_value = recursively_unfreeze(value)
try:
self.schema.validate(unfrozen_value)
value = unfrozen_value # Validators may update the instance inplace
except AttributeError:
jsonschema.validate(instance=unfrozen_value, schema=self.schema)
return recursively_freeze(value)

def parse(self, source):
"""
Expand Down Expand Up @@ -1212,6 +1238,31 @@ def run(self):
$ luigi --module my_tasks MyTask --numbers '[]' # must have at least 1 element
$ luigi --module my_tasks MyTask --numbers '[-999, 999]' # elements must be in [0, 10]
Finally, the provided schema can be a custom validator:
.. code-block:: python
custom_validator = jsonschema.Draft4Validator(
schema={
"type": "array",
"items": {
"type": "number",
"minimum": 0,
"maximum": 10
},
"minItems": 1
}
)
class MyTask(luigi.Task):
grades = luigi.ListParameter(schema=custom_validator)
def run(self):
sum = 0
for element in self.grades:
sum += element
avg = sum / len(self.grades)
"""

def __init__(
Expand All @@ -1225,7 +1276,9 @@ def __init__(
"The 'jsonschema' package is not installed so the parameter can not be validated "
"even though a schema is given."
)
self.schema = schema
self.schema = None
else:
self.schema = schema
super().__init__(
*args,
**kwargs,
Expand All @@ -1238,10 +1291,14 @@ def normalize(self, x):
:param str x: the value to normalize.
:return: the normalized (hashable/immutable) value.
"""
frozen_value = recursively_freeze(x)
if self.schema is not None:
jsonschema.validate(instance=recursively_unfreeze(frozen_value), schema=self.schema)
return frozen_value
unfrozen_value = recursively_unfreeze(x)
try:
self.schema.validate(unfrozen_value)
x = unfrozen_value # Validators may update the instance inplace
except AttributeError:
jsonschema.validate(instance=unfrozen_value, schema=self.schema)
return recursively_freeze(x)

def parse(self, x):
"""
Expand Down
20 changes: 20 additions & 0 deletions test/dict_parameter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
#

from jsonschema import Draft4Validator
from jsonschema.exceptions import ValidationError
from helpers import unittest, in_parse

Expand Down Expand Up @@ -113,6 +114,7 @@ def test_schema(self):
with pytest.raises(ValidationError, match=r"'UNKNOWN_VALUE' is not one of \['web', 'staging'\]"):
b.normalize({"role": "UNKNOWN_VALUE", "env": "staging"})

# Check that warnings are properly emitted
with mock.patch('luigi.parameter._JSONSCHEMA_ENABLED', False):
with pytest.warns(
UserWarning,
Expand All @@ -122,3 +124,21 @@ def test_schema(self):
)
):
luigi.ListParameter(schema={"type": "object"})

# Test with a custom validator
validator = Draft4Validator(
schema={
"type": "object",
"patternProperties": {
".*": {"type": "string", "enum": ["web", "staging"]},
},
}
)
c = luigi.DictParameter(schema=validator)
c.normalize({"role": "web", "env": "staging"})
with pytest.raises(ValidationError, match=r"'UNKNOWN_VALUE' is not one of \['web', 'staging'\]"):
c.normalize({"role": "UNKNOWN_VALUE", "env": "staging"})

# Test with frozen data
frozen_data = luigi.freezing.recursively_freeze({"role": "web", "env": "staging"})
c.normalize(frozen_data)
28 changes: 24 additions & 4 deletions test/list_parameter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.
#

from jsonschema import Draft4Validator
from jsonschema.exceptions import ValidationError
from helpers import unittest, in_parse

Expand Down Expand Up @@ -76,10 +77,7 @@ def test_schema(self):
)

# Check that the default value is validated
with pytest.raises(
ValidationError,
match=r"'INVALID_ATTRIBUTE' is not of type 'number'",
):
with pytest.raises(ValidationError, match=r"'INVALID_ATTRIBUTE' is not of type 'number'"):
a.normalize(["INVALID_ATTRIBUTE"])

# Check that empty list is not valid
Expand All @@ -100,6 +98,7 @@ def test_schema(self):
with pytest.raises(ValidationError, match="-999 is less than the minimum of 0"):
a.normalize(invalid_list_value)

# Check that warnings are properly emitted
with mock.patch('luigi.parameter._JSONSCHEMA_ENABLED', False):
with pytest.warns(
UserWarning,
Expand All @@ -109,3 +108,24 @@ def test_schema(self):
)
):
luigi.ListParameter(schema={"type": "array", "items": {"type": "number"}})

# Test with a custom validator
validator = Draft4Validator(
schema={
"type": "array",
"items": {
"type": "number",
"minimum": 0,
"maximum": 10,
},
"minItems": 1,
}
)
c = luigi.DictParameter(schema=validator)
c.normalize(valid_list)
with pytest.raises(ValidationError, match=r"'INVALID_ATTRIBUTE' is not of type 'number'",):
c.normalize(["INVALID_ATTRIBUTE"])

# Test with frozen data
frozen_data = luigi.freezing.recursively_freeze(valid_list)
c.normalize(frozen_data)
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ deps =
postgres: psycopg2<3.0
postgres: pg8000>=1.23.0
mysql-connector-python>=8.0.12
py35,py36: mysql-connector-python<8.0.32
gcloud: google-api-python-client>=1.6.6,<2.0
avro-python3
gcloud: google-auth==1.4.1
Expand Down

0 comments on commit ddd94d0

Please sign in to comment.