Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ProductsQueryset manager execute duplicated channel queries #15858

Merged
merged 3 commits into from Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 7 additions & 2 deletions saleor/graphql/attribute/filters.py
Expand Up @@ -4,6 +4,7 @@

from ...attribute import AttributeInputType
from ...attribute.models import Attribute, AttributeValue
from ...channel.models import Channel
from ...permission.utils import has_one_of_permissions
from ...product import models
from ...product.models import ALL_PRODUCTS_PERMISSIONS
Expand Down Expand Up @@ -42,8 +43,12 @@ def filter_attributes_by_product_types(qs, field, value, requestor, channel_slug
if not value:
return qs

channel = None
if channel_slug is not None:
channel = Channel.objects.using(qs.db).filter(slug=str(channel_slug)).first()
limited_channel_access = False if channel_slug is None else True
product_qs = models.Product.objects.using(qs.db).visible_to_user(
requestor, channel_slug
requestor, channel, limited_channel_access
)

if field == "in_category":
Expand All @@ -57,7 +62,7 @@ def filter_attributes_by_product_types(qs, field, value, requestor, channel_slug
product_qs = product_qs.filter(category__in=tree)

if not has_one_of_permissions(requestor, ALL_PRODUCTS_PERMISSIONS):
product_qs = product_qs.annotate_visible_in_listings(channel_slug).exclude(
product_qs = product_qs.annotate_visible_in_listings(channel).exclude(
visible_in_listings=False
)

Expand Down
10 changes: 5 additions & 5 deletions saleor/graphql/order/utils.py
Expand Up @@ -209,7 +209,7 @@ def validate_product_is_published(
unpublished_product = (
Product.objects.using(database_connection_name)
.filter(variants__id__in=variant_ids)
.not_published(channel.slug)
.not_published(channel)
)
if unpublished_product.exists():
errors["lines"].append(
Expand Down Expand Up @@ -238,7 +238,7 @@ def validate_product_is_published_in_channel(
unpublished_product = list(
Product.objects.using(database_connection_name)
.filter(variants__id__in=variant_ids)
.not_published(channel.slug)
.not_published(channel)
)
if unpublished_product:
unpublished_variants = (
Expand Down Expand Up @@ -398,9 +398,9 @@ def prepare_insufficient_stock_order_validation_errors(exc):
"Insufficient product stock.",
code=OrderErrorCode.INSUFFICIENT_STOCK.value,
params={
"order_lines": [order_line_global_id]
if order_line_global_id
else [],
"order_lines": (
[order_line_global_id] if order_line_global_id else []
),
"warehouse": warehouse_global_id,
},
)
Expand Down
43 changes: 27 additions & 16 deletions saleor/graphql/product/resolvers.py
@@ -1,3 +1,5 @@
from typing import Optional

from django.db.models import Exists, OuterRef, Sum

from ...channel.models import Channel
Expand Down Expand Up @@ -66,11 +68,17 @@ def resolve_digital_contents(info: ResolveInfo):


def resolve_product(
info: ResolveInfo, id, slug, external_reference, channel_slug, requestor
info: ResolveInfo,
id,
slug,
external_reference,
channel: Optional[Channel],
limited_channel_access: bool,
requestor,
):
database_connection_name = get_database_connection_name(info.context)
qs = models.Product.objects.using(database_connection_name).visible_to_user(
requestor, channel_slug=channel_slug
requestor, channel, limited_channel_access
)
if id:
_type, id = from_global_id_or_error(id, "Product")
Expand All @@ -83,18 +91,17 @@ def resolve_product(

@traced_resolver
def resolve_products(
info: ResolveInfo, requestor, channel_slug=None
info: ResolveInfo,
requestor,
channel: Optional[Channel],
limited_channel_access: bool,
) -> ChannelQsContext:
connection_name = get_database_connection_name(info.context)
qs = models.Product.objects.using(connection_name).visible_to_user(
requestor, channel_slug
requestor, channel, limited_channel_access
)
if not has_one_of_permissions(requestor, ALL_PRODUCTS_PERMISSIONS):
if channel := (
Channel.objects.using(connection_name)
.filter(slug=str(channel_slug))
.first()
):
if channel:
product_channel_listings = (
models.ProductChannelListing.objects.using(connection_name)
.filter(channel_id=channel.id, visible_in_listings=True)
Expand All @@ -105,6 +112,7 @@ def resolve_products(
)
else:
qs = models.Product.objects.none()
channel_slug = channel.slug if channel else None
return ChannelQsContext(qs=qs, channel_slug=channel_slug)


Expand All @@ -129,21 +137,22 @@ def resolve_variant(
sku,
external_reference,
*,
channel_slug,
channel: Optional[Channel],
limited_channel_access: bool,
requestor,
requestor_has_access_to_all,
):
connection_name = get_database_connection_name(info.context)
visible_products = (
models.Product.objects.using(connection_name)
.visible_to_user(requestor, channel_slug)
.visible_to_user(requestor, channel, limited_channel_access)
.values_list("pk", flat=True)
)
qs = models.ProductVariant.objects.using(connection_name).filter(
product__id__in=visible_products
)
if not requestor_has_access_to_all:
qs = qs.available_in_channel(channel_slug)
qs = qs.available_in_channel(channel)
if id:
_, id = from_global_id_or_error(id, "ProductVariant")
return qs.filter(pk=id).first()
Expand All @@ -159,24 +168,26 @@ def resolve_product_variants(
requestor_has_access_to_all,
requestor,
ids=None,
channel_slug=None,
channel: Optional[Channel] = None,
limited_channel_access: bool = False,
) -> ChannelQsContext:
connection_name = get_database_connection_name(info.context)
visible_products = models.Product.objects.using(connection_name).visible_to_user(
requestor, channel_slug
requestor, channel, limited_channel_access
)
qs = models.ProductVariant.objects.using(connection_name).filter(
product__id__in=visible_products
)

channel_slug = channel.slug if channel else None
if not requestor_has_access_to_all:
visible_products = visible_products.annotate_visible_in_listings(
channel_slug
channel
).exclude(visible_in_listings=False)
qs = (
qs.using(connection_name)
.filter(product__in=visible_products)
.available_in_channel(channel_slug)
.available_in_channel(channel)
)
if ids:
db_ids = [
Expand Down
137 changes: 95 additions & 42 deletions saleor/graphql/product/schema.py
Expand Up @@ -5,6 +5,7 @@
from ...product.models import ALL_PRODUCTS_PERMISSIONS
from ...product.search import search_products
from ..channel import ChannelContext, ChannelQsContext
from ..channel.dataloaders import ChannelBySlugLoader
from ..channel.utils import get_default_channel_slug_or_graphql_error
from ..core import ResolveInfo
from ..core.connection import create_connection_slice, filter_connection_queryset
Expand Down Expand Up @@ -436,21 +437,35 @@ def resolve_product(
requestor, ALL_PRODUCTS_PERMISSIONS
)

limited_channel_access = False if channel is None else True
if channel is None and not has_required_permissions:
channel = get_default_channel_slug_or_graphql_error(
allow_replica=info.context.allow_replica
)

product = resolve_product(
info,
id=id,
slug=slug,
external_reference=external_reference,
channel_slug=channel,
requestor=requestor,
)
def _resolve_product(channel_obj):
product = resolve_product(
info,
id=id,
slug=slug,
external_reference=external_reference,
channel=channel_obj,
limited_channel_access=limited_channel_access,
requestor=requestor,
)

return ChannelContext(node=product, channel_slug=channel) if product else None
return (
ChannelContext(node=product, channel_slug=channel) if product else None
)

if channel:
return (
ChannelBySlugLoader(info.context)
.load(str(channel))
.then(_resolve_product)
)
else:
return _resolve_product(None)

@staticmethod
@traced_resolver
Expand All @@ -462,20 +477,32 @@ def resolve_products(_root, info: ResolveInfo, *, channel=None, **kwargs):
has_required_permissions = has_one_of_permissions(
requestor, ALL_PRODUCTS_PERMISSIONS
)
limited_channel_access = False if channel is None else True
if channel is None and not has_required_permissions:
channel = get_default_channel_slug_or_graphql_error(
allow_replica=info.context.allow_replica
)
qs = resolve_products(info, requestor, channel_slug=channel)
if search:
qs = ChannelQsContext(
qs=search_products(qs.qs, search), channel_slug=channel

def _resolve_products(channel_obj):
qs = resolve_products(info, requestor, channel_obj, limited_channel_access)
if search:
qs = ChannelQsContext(
qs=search_products(qs.qs, search), channel_slug=channel
)
kwargs["channel"] = channel
qs = filter_connection_queryset(
qs, kwargs, allow_replica=info.context.allow_replica
)
kwargs["channel"] = channel
qs = filter_connection_queryset(
qs, kwargs, allow_replica=info.context.allow_replica
)
return create_connection_slice(qs, info, kwargs, ProductCountableConnection)
return create_connection_slice(qs, info, kwargs, ProductCountableConnection)

if channel:
return (
ChannelBySlugLoader(info.context)
.load(str(channel))
.then(_resolve_products)
)
else:
return _resolve_products(None)

@staticmethod
def resolve_product_type(_root, info: ResolveInfo, *, id):
Expand Down Expand Up @@ -509,22 +536,35 @@ def resolve_product_variant(
requestor, ALL_PRODUCTS_PERMISSIONS
)

limited_channel_access = False if channel is None else True
if channel is None and not has_required_permissions:
channel = get_default_channel_slug_or_graphql_error(
allow_replica=info.context.allow_replica
)

variant = resolve_variant(
info,
id,
sku,
external_reference,
channel_slug=channel,
requestor=requestor,
requestor_has_access_to_all=has_required_permissions,
)
def _resolve_product_variant(channel_obj):
variant = resolve_variant(
info,
id,
sku,
external_reference,
channel=channel_obj,
limited_channel_access=limited_channel_access,
requestor=requestor,
requestor_has_access_to_all=has_required_permissions,
)
return (
ChannelContext(node=variant, channel_slug=channel) if variant else None
)

return ChannelContext(node=variant, channel_slug=channel) if variant else None
if channel:
return (
ChannelBySlugLoader(info.context)
.load(str(channel))
.then(_resolve_product_variant)
)
else:
return _resolve_product_variant(None)

@staticmethod
def resolve_product_variants(
Expand All @@ -534,24 +574,37 @@ def resolve_product_variants(
has_required_permissions = has_one_of_permissions(
requestor, ALL_PRODUCTS_PERMISSIONS
)
limited_channel_access = False if channel is None else True
if channel is None and not has_required_permissions:
channel = get_default_channel_slug_or_graphql_error(
allow_replica=info.context.allow_replica
)
qs = resolve_product_variants(
info,
ids=ids,
channel_slug=channel,
requestor_has_access_to_all=has_required_permissions,
requestor=requestor,
)
kwargs["channel"] = qs.channel_slug
qs = filter_connection_queryset(
qs, kwargs, allow_replica=info.context.allow_replica
)
return create_connection_slice(
qs, info, kwargs, ProductVariantCountableConnection
)

def _resolve_product_variants(channel_obj):
qs = resolve_product_variants(
info,
ids=ids,
channel=channel_obj,
limited_channel_access=limited_channel_access,
requestor_has_access_to_all=has_required_permissions,
requestor=requestor,
)
kwargs["channel"] = qs.channel_slug
qs = filter_connection_queryset(
qs, kwargs, allow_replica=info.context.allow_replica
)
return create_connection_slice(
qs, info, kwargs, ProductVariantCountableConnection
)

if channel:
return (
ChannelBySlugLoader(info.context)
.load(str(channel))
.then(_resolve_product_variants)
)
else:
return _resolve_product_variants(None)

@staticmethod
@traced_resolver
Expand Down
4 changes: 2 additions & 2 deletions saleor/graphql/product/tests/benchmark/test_product.py
Expand Up @@ -743,7 +743,7 @@ def test_products_for_federation_query_count(
],
}

with django_assert_num_queries(5):
with django_assert_num_queries(3):
response = api_client.post_graphql(query, variables)
content = get_graphql_content(response)
assert len(content["data"]["_entities"]) == 1
Expand All @@ -765,7 +765,7 @@ def test_products_for_federation_query_count(
],
}

with django_assert_num_queries(5):
with django_assert_num_queries(3):
response = api_client.post_graphql(query, variables)
content = get_graphql_content(response)
assert len(content["data"]["_entities"]) == 2
Expand Down