Skip to content

Commit

Permalink
Fixed #35339 -- Fixed PostgreSQL aggregate's filter and order_by para…
Browse files Browse the repository at this point in the history
…ms order.

Updated OrderableAggMixin.as_sql() to separate the order_by parameters
from the filter parameters. Previously, the parameters and SQL were
calculated by the Aggregate parent class, resulting in a mixture of
order_by and filter parameters.

Thanks Simon Charette for the review.
  • Loading branch information
camuthig authored and nessita committed Apr 25, 2024
1 parent 42b567a commit c8df2f9
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
27 changes: 21 additions & 6 deletions django/contrib/postgres/aggregates/mixins.py
@@ -1,3 +1,4 @@
from django.core.exceptions import FullResultSet
from django.db.models.expressions import OrderByList


Expand All @@ -24,9 +25,23 @@ def set_source_expressions(self, exprs):
return super().set_source_expressions(exprs)

def as_sql(self, compiler, connection):
if self.order_by is not None:
order_by_sql, order_by_params = compiler.compile(self.order_by)
else:
order_by_sql, order_by_params = "", ()
sql, sql_params = super().as_sql(compiler, connection, ordering=order_by_sql)
return sql, (*sql_params, *order_by_params)
*source_exprs, filtering_expr, ordering_expr = self.get_source_expressions()

order_by_sql = ""
order_by_params = []
if ordering_expr is not None:
order_by_sql, order_by_params = compiler.compile(ordering_expr)

filter_params = []
if filtering_expr is not None:
try:
_, filter_params = compiler.compile(filtering_expr)
except FullResultSet:
pass

source_params = []
for source_expr in source_exprs:
source_params += compiler.compile(source_expr)[1]

sql, _ = super().as_sql(compiler, connection, ordering=order_by_sql)
return sql, (*source_params, *order_by_params, *filter_params)
12 changes: 11 additions & 1 deletion tests/postgres_tests/test_aggregates.py
Expand Up @@ -12,7 +12,7 @@
Window,
)
from django.db.models.fields.json import KeyTextTransform, KeyTransform
from django.db.models.functions import Cast, Concat, Substr
from django.db.models.functions import Cast, Concat, LPad, Substr
from django.test import skipUnlessDBFeature
from django.test.utils import Approximate
from django.utils import timezone
Expand Down Expand Up @@ -238,6 +238,16 @@ def test_array_agg_jsonfield_ordering(self):
)
self.assertEqual(values, {"arrayagg": ["en", "pl"]})

def test_array_agg_filter_and_ordering_params(self):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg(
"char_field",
filter=Q(json_field__has_key="lang"),
ordering=LPad(Cast("integer_field", CharField()), 2, Value("0")),
)
)
self.assertEqual(values, {"arrayagg": ["Foo2", "Foo4"]})

def test_array_agg_filter(self):
values = AggregateTestModel.objects.aggregate(
arrayagg=ArrayAgg("integer_field", filter=Q(integer_field__gt=0)),
Expand Down

0 comments on commit c8df2f9

Please sign in to comment.