Skip to content

Commit

Permalink
Merge pull request #4 from dealertrack/enforce
Browse files Browse the repository at this point in the history
Fixed a bug in EnforceValidationFieldMixin
  • Loading branch information
miki725 committed Jul 15, 2015
2 parents 227b728 + a5ac540 commit a520f64
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 16 deletions.
12 changes: 8 additions & 4 deletions drf_braces/serializers/enforce_validation_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import inspect

from rest_framework import fields, serializers
from rest_framework.fields import empty

from ..utils import add_base_class_to_instance, get_class_name_with_new_suffix


class EnforceValidationFieldMixin(fields.Field):
class EnforceValidationFieldMixin(object):
"""
Custom DRF field mixin which allows to ignore validation error
if the field is not mandatory.
Expand All @@ -17,9 +18,9 @@ class EnforceValidationFieldMixin(fields.Field):
fields in that serializer are mandatory and must validate.
"""

def to_internal_value(self, data):
def run_validation(self, data=empty):
try:
return super(EnforceValidationFieldMixin, self).to_internal_value(data)
return super(EnforceValidationFieldMixin, self).run_validation(data)
except serializers.ValidationError:
must_validate_fields = getattr(self.parent, 'must_validate_fields', None)
field_name = getattr(self, 'field_name')
Expand All @@ -30,7 +31,10 @@ def to_internal_value(self, data):
if must_validate_fields is None or field_name in must_validate_fields:
raise
else:
raise fields.SkipField
raise fields.SkipField(
'This field "{}" is being skipped as per enforce validation logic.'
''.format(field_name)
)


def _create_enforce_validation_serializer(serializer, strict_mode_by_default=True):
Expand Down
5 changes: 4 additions & 1 deletion drf_braces/tests/fields/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import unittest

from drf_braces.fields import _fields
from drf_braces.fields.mixins import AllowBlankNullFieldMixin, EmptyStringFieldMixin
from drf_braces.fields.mixins import (
AllowBlankNullFieldMixin,
EmptyStringFieldMixin,
)


class TestFields(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,36 +35,36 @@ class TestEnforceValidationFieldMixin(unittest.TestCase):
class Field(EnforceValidationFieldMixin, fields.IntegerField):
pass

def test_field_from_native_must_validate(self):
def test_run_validation_must_validate(self):
field = self.Field()
field.field_name = 'field'
field.parent = mock.MagicMock(must_validate_fields=None)

self.assertEqual(field.to_internal_value('5'), 5)
self.assertEqual(field.run_validation('5'), 5)

def test_field_from_native_must_validate_all(self):
def test_run_validation_must_validate_all(self):
field = self.Field()
field.field_name = 'field'
field.parent = mock.MagicMock(must_validate_fields=None)

with self.assertRaises(serializers.ValidationError):
field.to_internal_value('hello')
field.run_validation('hello')

def test_field_from_native_must_validate_invalid(self):
def test_run_validation_must_validate_invalid(self):
field = self.Field()
field.field_name = 'field'
field.parent = mock.MagicMock(must_validate_fields=['field'])

with self.assertRaises(serializers.ValidationError):
field.to_internal_value('hello')
field.run_validation('hello')

def test_field_from_native_must_validate_ignore(self):
def test_run_validation_must_validate_ignore(self):
field = self.Field()
field.field_name = 'field'
field.parent = mock.MagicMock(must_validate_fields=[])

with self.assertRaises(serializers.SkipField):
field.to_internal_value('hello')
field.run_validation('hello')


class TestUtils(unittest.TestCase):
Expand Down
29 changes: 26 additions & 3 deletions drf_braces/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import print_function, unicode_literals

import unittest

from rest_framework import fields

from drf_braces.utils import (
get_class_name_with_new_suffix,
find_class_args,
find_function_args,
get_attr_from_base_classes,
get_class_name_with_new_suffix,
)
from rest_framework import fields


class TestUtils(unittest.TestCase):
Expand Down Expand Up @@ -48,3 +50,24 @@ def test_get_attr_from_base_classes(self):
get_attr_from_base_classes(
(Parent,), {'fields': 'mushrooms'}, 'catchmeifyoucan'
)

def test_find_function_args(self):
def foo(a, b, c):
pass

self.assertListEqual(find_function_args(foo), ['a', 'b', 'c'])

def test_find_function_args_invalid(self):
self.assertListEqual(find_function_args(None), [])

def test_find_class_args(self):
class Bar(object):
def __init__(self, a, b):
pass

class Foo(Bar):
def __init__(self, c, d):
super(Foo, self).__init__(None, None)
pass

self.assertSetEqual(set(find_class_args(Foo)), {'a', 'b', 'c', 'd'})

0 comments on commit a520f64

Please sign in to comment.