Skip to content

Commit

Permalink
Refs #35339 -- Updated Aggregate class to return consistent source ex…
Browse files Browse the repository at this point in the history
…pressions.

Refactored the filter and order_by expressions in the Aggregate class to
return a list of Expression (or None) values, ensuring that the list
item is always available and represents the filter expression.
For the PostgreSQL OrderableAggMixin, the returned list will always
include the filter and the order_by value as the last two elements.

Lastly, emtpy Q objects passed directly into aggregate objects using
Aggregate.filter in admin facets are filtered out when resolving the
expression to avoid errors in get_refs().

Thanks Simon Charette for the review.
  • Loading branch information
camuthig authored and nessita committed Apr 25, 2024
1 parent ec85524 commit 42b567a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
7 changes: 2 additions & 5 deletions django/contrib/postgres/aggregates/mixins.py
Expand Up @@ -17,13 +17,10 @@ def resolve_expression(self, *args, **kwargs):
return super().resolve_expression(*args, **kwargs)

def get_source_expressions(self):
if self.order_by is not None:
return super().get_source_expressions() + [self.order_by]
return super().get_source_expressions()
return super().get_source_expressions() + [self.order_by]

def set_source_expressions(self, exprs):
if isinstance(exprs[-1], OrderByList):
*exprs, self.order_by = exprs
*exprs, self.order_by = exprs
return super().set_source_expressions(exprs)

def as_sql(self, compiler, connection):
Expand Down
16 changes: 9 additions & 7 deletions django/db/models/aggregates.py
Expand Up @@ -50,21 +50,21 @@ def get_source_fields(self):

def get_source_expressions(self):
source_expressions = super().get_source_expressions()
if self.filter:
return source_expressions + [self.filter]
return source_expressions
return source_expressions + [self.filter]

def set_source_expressions(self, exprs):
self.filter = self.filter and exprs.pop()
*exprs, self.filter = exprs
return super().set_source_expressions(exprs)

def resolve_expression(
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
):
# Aggregates are not allowed in UPDATE queries, so ignore for_save
c = super().resolve_expression(query, allow_joins, reuse, summarize)
c.filter = c.filter and c.filter.resolve_expression(
query, allow_joins, reuse, summarize
c.filter = (
c.filter.resolve_expression(query, allow_joins, reuse, summarize)
if c.filter
else None
)
if summarize:
# Summarized aggregates cannot refer to summarized aggregates.
Expand Down Expand Up @@ -104,7 +104,9 @@ def resolve_expression(

@property
def default_alias(self):
expressions = self.get_source_expressions()
expressions = [
expr for expr in self.get_source_expressions() if expr is not None
]
if len(expressions) == 1 and hasattr(expressions[0], "name"):
return "%s__%s" % (expressions[0].name, self.name.lower())
raise TypeError("Complex expressions require an alias")
Expand Down
2 changes: 1 addition & 1 deletion tests/aggregation/tests.py
Expand Up @@ -1291,7 +1291,7 @@ class MyMax(Max):

def as_sql(self, compiler, connection):
copy = self.copy()
copy.set_source_expressions(copy.get_source_expressions()[0:1])
copy.set_source_expressions(copy.get_source_expressions()[0:1] + [None])
return super(MyMax, copy).as_sql(compiler, connection)

with self.assertRaisesMessage(TypeError, "Complex aggregates require an alias"):
Expand Down

0 comments on commit 42b567a

Please sign in to comment.