Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add group validation & filtering #1167

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 6 additions & 2 deletions django_filters/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ def __init__(self, field_name=None, lookup_expr='exact', *, label=None,
self.creation_counter = Filter.creation_counter
Filter.creation_counter += 1

# Set during parent FilterSet class creation
self.parent = None
self.model = None

# TODO: remove assertion in 2.1
assert not isinstance(self.lookup_expr, (type(None), list)), \
"The `lookup_expr` argument no longer accepts `None` or a list of " \
Expand Down Expand Up @@ -117,7 +121,7 @@ def fset(self, value):

def label():
def fget(self):
if self._label is None and hasattr(self, 'model'):
if self._label is None and self.model is not None:
self._label = label_for_filter(
self.model, self.field_name, self.lookup_expr, self.exclude
)
Expand Down Expand Up @@ -776,7 +780,7 @@ def method(self):
return instance.method

# otherwise, method is the name of a method on the parent FilterSet.
assert hasattr(instance, 'parent'), \
assert instance.parent is not None, \
"Filter '%s' must have a parent FilterSet to find '.%s()'" % \
(instance.field_name, instance.method)

Expand Down
33 changes: 31 additions & 2 deletions django_filters/filterset.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, options=None):
self.model = getattr(options, 'model', None)
self.fields = getattr(options, 'fields', None)
self.exclude = getattr(options, 'exclude', None)
self.groups = getattr(options, 'groups', [])

self.filter_overrides = getattr(options, 'filter_overrides', {})

Expand Down Expand Up @@ -193,12 +194,18 @@ def __init__(self, data=None, queryset=None, *, request=None, prefix=None):
self.form_prefix = prefix

self.filters = copy.deepcopy(self.base_filters)
self.groups = copy.deepcopy(self._meta.groups)

# propagate the model and filterset to the filters
for filter_ in self.filters.values():
filter_.model = model
filter_.parent = self

# propagate the model and filterset to the groups
for group in self.groups:
group.model = model
group.parent = self

def is_valid(self):
"""
Return True if the underlying form has no errors, or False otherwise.
Expand All @@ -220,11 +227,21 @@ def filter_queryset(self, queryset):
This method should be overridden if additional filtering needs to be
applied to the queryset before it is cached.
"""
for name, value in self.form.cleaned_data.items():
cleaned_data = self.form.cleaned_data.copy()

# Extract the grouped data from the rest of the `cleaned_data`. This
# ensures that the original filter methods aren't called in addition
# to the group filter methods.
for group in self.groups:
group_data, cleaned_data = group.extract_data(cleaned_data)
queryset = group.filter(queryset, **group_data)

for name, value in cleaned_data.items():
queryset = self.filters[name].filter(queryset, value)
assert isinstance(queryset, models.QuerySet), \
"Expected '%s.%s' to return a QuerySet, but got a %s instead." \
% (type(self).__name__, name, type(queryset).__name__)

return queryset

@property
Expand All @@ -245,12 +262,24 @@ def get_form_class(self):
This method should be overridden if the form class needs to be
customized relative to the filterset instance.
"""
class FilterSetForm(forms.Form):
def clean(form):
cleaned_data = super().clean()

for group in self.groups:
# Ignore the modified `cleaned_data`, as we only want to remove
# data on error, which is accomplished through `form.add_error`.
group_data, _ = group.extract_data(cleaned_data)
group.validate(form, **group_data)

return cleaned_data

fields = OrderedDict([
(name, filter_.field)
for name, filter_ in self.filters.items()])

return type(str('%sForm' % self.__class__.__name__),
(self._meta.form,), fields)
(FilterSetForm, self._meta.form,), fields)

@property
def form(self):
Expand Down
289 changes: 289 additions & 0 deletions django_filters/groups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
import functools
import operator
from abc import abstractmethod

from django.core.exceptions import ValidationError
from django.db.models import Q
from django.db.models.constants import LOOKUP_SEP
from django.utils.translation import ugettext as _

from .constants import EMPTY_VALUES

__all__ = [
'BaseFilterGroup', 'ExclusiveGroup', 'RequiredGroup',
'CombinedGroup', 'CombinedRequiredGroup',
]


class BaseFilterGroup:
"""Base class for validating & filtering query parameters as a group."""

def __init__(self, filters):
if not filters or len(filters) < 2:
msg = "A filter group must contain at least two members."
raise ValueError(msg)
if len(set(filters)) != len(filters):
msg = "A filter group must not contain duplicate members."
raise ValueError(msg)
self.filters = filters

# Set during parent FilterSet initialization
self.parent = None
self.model = None

@abstractmethod
def validate(self, form, **data):
"""Validate the subset of cleaned data provided by the form.

Args:
form: The underlying ``Form`` instance used to validate the query
params. A ``FilterGroup`` should add errors to this form using
``form.add_error(<field name>, <error message>)``.
**data: The subset of a form's ``cleaned_data`` for the filters in
the group.
"""
raise NotImplementedError

@abstractmethod
def filter(self, qs, **data):
"""Filter the result queryset with the subset of cleaned data.

Args:
qs: The ``QuerySet`` instance to filter.
**data: The subset of a form's ``cleaned_data`` for the filters in
the group.

Returns:
The filtered queryset instance.
"""
raise NotImplementedError

def format_labels(self, filters):
"""Return a formatted string of labels for the given filter names.

This inspects the filter labels from the ``self.parent`` FilterSet,
and combines them into a formatted string. This string can then be
used in validation error messages.

Args:
filters: A list of filter names.

Returns:
The formatted string of labels. For example, if filters 'a', 'b',
and 'c' have corresponding labels 'Filter A', 'Filter B', and
'Filter C', then...

>>> group.format_labels(['a', 'b'])
"'Filter A' and 'Filter B'"

>>> group.format_labels(['a', 'b', 'c'])
"'Filter A', 'Filter B', and 'Filter C'"
"""
labels = [self.parent.filters[name].label for name in filters]

if len(labels) == 2:
return "'%s' and '%s'" % tuple(labels)

# e.g., joined = "'A', 'B', and 'C'"
joined = ', '.join("'%s'" % l for l in labels)
joined = ', and '.join(joined.rsplit(', ', 1))
return joined

def extract_data(self, cleaned_data):
"""Extract the subset of cleaned data for the filters in the group.

Note that this is an internal method called by the ``FilterSet``.
Subclasses should not need to call or override this method.

Args:
cleaned_data: The underlying form's ``cleaned_data`` dict.

Returns:
A two-tuple containing the dict subset of data for the filters in
the group, and the remainder of the original ``cleaned_data`` dict.
"""
# Create a copy so as to not modify the original data dict.
data = cleaned_data.copy()

return {
name: data.pop(name)
for name in self.filters
if name in data
}, data

def _filter_data(self, data):
# Helper for checking the extacted data.
# - Sanity check that correct data has been provided by the filterset.
# - Remove empty values that would normally be skipped by the
# ``Filter.filter`` method.
assert set(data).issubset(set(self.filters)), (
"The `data` must be a subset of the group's `.filters`.")
return {k: v for k, v in data.items() if v not in EMPTY_VALUES}


class ExclusiveGroup(BaseFilterGroup):
"""A group of mutually exclusive filters.

If any filter in the group is included in the request data, then all other
filters in the group **must not** be present in the request data.

Attributes:
filters: The set of filter names to operate on.
"""

def validate(self, form, **data):
data = self._filter_data(data)

if len(data) > 1:
err = ValidationError(
_('%(filters)s are mutually exclusive.'),
params={'filters': self.format_labels(self.filters)})

# The error message should include all filters (A, B, and C),
# but the error should only be raised for the given filters.
for param in data:
form.add_error(param, err)

def filter(self, qs, **data):
data = self._filter_data(data)
if not data:
return qs

assert len(data) <= 1, "The `data` should consist of only one element."

param, value = next(iter(data.items()))

return self.parent.filters[param].filter(qs, value)


class RequiredGroup(BaseFilterGroup):
"""A group of mutually required filters.

If any filter in the group is included in the request data, then all other
filters in the group **must** be present in the request data. Filtering is
still performed by the individual filters and is not combined via ``Q``
objects. To use ``Q`` objects instead (e.g., for OR-based filtering), use
the ``CombinedRequiredGroup``.

Attributes:
filters: The set of filter names to operate on.
"""

def validate(self, form, **data):
data = self._filter_data(data)

if data and set(data) != set(self.filters):
err = ValidationError(
_('%(filters)s are mutually required.'),
params={'filters': self.format_labels(self.filters)})

# Unlike ``ExclusiveGroup``, the error should be raised for all
# filters since the missing filters are part of the error state.
for param in self.filters:
form.add_error(param, err)

def filter(self, qs, **data):
data = self._filter_data(data)

assert not data or len(data) == len(self.filters), (
"The `data` should contain all filters.")

# Filter by chaining the original filter method calls.
for param, value in data.items():
qs = self.parent.filters[param].filter(qs, value)

return qs


class CombinedGroup(BaseFilterGroup):
"""A group of filters that result in a combined query (a ``Q`` object).

This implementation combines ``Q`` objects *instead* of chaining
``.filter()`` calls. The ``Q`` objects are generated from the filter's
``field_name``, ``lookup_expr``, and ``exclude`` attributes, and the
resulting queryset will call ``.distinct()`` if set on any of the filters.

In short, instead of generating the following filter call:

.. code-block:: python

qs.filter(a=1).filter(b=2)

This group would generate a call like:

.. code-block:: python

qs.filter(Q(a=1) & Q(b=2))

This is useful for enabling OR filtering, as well as combining filters that
span multi-valued relationships (`more info`__).

__ https://docs.djangoproject.com/en/stable/topics/db/queries/#spanning-multi-valued-relationships

Attributes:
filters: The set of filter names to operate on.
combine: A function that combines two ``Q`` objects. Defaults to
``operator.and_``. For OR operations, use ``operator.or_``.
"""

def __init__(self, filters, combine=operator.and_):
super().__init__(filters)
self.combine = combine

def validate(self, form, **data):
# CombinedGroup has no specific validation rules.
self._filter_data(data)

def filter(self, qs, **data):
data = self._filter_data(data)

if not data:
return qs

# Filter by combining the set of constructed Q objects.
qs = qs.filter(functools.reduce(self.combine, [
self.build_q_object(param, value)
for param, value in data.items()]))

# If any filter is marked as distinct, the qs should also be distinct.
if any(self.parent.filters[param].distinct for param in data):
qs = qs.distinct()

return qs

def build_q_object(self, filter_name, value):
"""Build a ``Q`` object for the given filter name and value.

The ``Q`` objects are generated from the filter's ``field_name``,
``lookup_expr``, and ``exclude`` attributes.

Args:
filter_name: The name of the filter to base the ``Q`` object off of.
value: The value to filter within the ``Q`` object.

Returns:
A ``Q`` object that is reprentative of the filter and value.
"""
f = self.parent.filters[filter_name]
q = Q(**{LOOKUP_SEP.join([f.field_name, f.lookup_expr]): value})
if f.exclude:
q = ~q

return q


class CombinedRequiredGroup(CombinedGroup, RequiredGroup):
"""A group of mutually required filters that result in a combined query.

This combines the validation logic of a ``RequiredGroup`` with the
filtering logic of a ``CombinedGroup``.

Attributes:
filters: The set of filter names to operate on.
combine: A function that combines two ``Q`` objects. Defaults to
``operator.and_``. For OR operations, use ``operator.or_``.
"""

def validate(self, form, **data):
# Use the validation provided by RequiredGroup
super(CombinedGroup, self).validate(form, **data)