Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimise productVariant.stocks query. #15894

Merged
merged 1 commit into from May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -14,9 +14,7 @@
from ...core.mutations import BaseMutation
from ...core.types import BulkStockError, NonNullList
from ...plugins.dataloaders import get_plugin_manager_promise
from ...warehouse.dataloaders import (
StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader,
)
from ...warehouse.dataloaders import StocksByProductVariantIdLoader
from ...warehouse.types import Warehouse
from ..mutations.product.product_create import StockInput
from ..types import ProductVariant
Expand Down Expand Up @@ -69,9 +67,7 @@ def perform_mutation(cls, _root, info: ResolveInfo, /, **data):
manager.product_variant_back_in_stock, stock, webhooks=webhooks
)

StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader(
info.context
).clear((variant.id, None, None))
StocksByProductVariantIdLoader(info.context).clear(variant.id)

variant = ChannelContext(node=variant, channel_slug=None)
return cls(product_variant=variant)
Expand Down
Expand Up @@ -14,9 +14,7 @@
from ...core.types import NonNullList, StockError
from ...core.validators import validate_one_of_args_is_in_mutation
from ...plugins.dataloaders import get_plugin_manager_promise
from ...warehouse.dataloaders import (
StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader,
)
from ...warehouse.dataloaders import StocksByProductVariantIdLoader
from ...warehouse.types import Warehouse
from ..types import ProductVariant

Expand Down Expand Up @@ -85,9 +83,7 @@ def perform_mutation(cls, _root, info: ResolveInfo, /, **data):

stocks_to_delete.delete()

StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader(
info.context
).clear((variant.id, None, None))
StocksByProductVariantIdLoader(info.context).clear(variant.id)

variant = ChannelContext(node=variant, channel_slug=None)
return cls(product_variant=variant)
Expand Up @@ -15,9 +15,7 @@
from ...core.types import BulkStockError, NonNullList
from ...core.validators import validate_one_of_args_is_in_mutation
from ...plugins.dataloaders import get_plugin_manager_promise
from ...warehouse.dataloaders import (
StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader,
)
from ...warehouse.dataloaders import StocksByProductVariantIdLoader
from ...warehouse.types import Warehouse
from ..mutations.product.product_create import StockInput
from ..types import ProductVariant
Expand Down Expand Up @@ -81,9 +79,7 @@ def perform_mutation(cls, _root, info: ResolveInfo, /, **data):
manager = get_plugin_manager_promise(info.context).get()
cls.update_or_create_variant_stocks(variant, stocks, warehouses, manager)

StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader(
info.context
).clear((variant.id, None, None))
StocksByProductVariantIdLoader(info.context).clear(variant.id)

variant = ChannelContext(node=variant, channel_slug=None)
return cls(product_variant=variant)
Expand Down
36 changes: 36 additions & 0 deletions saleor/graphql/product/tests/queries/test_product_variant_query.py
Expand Up @@ -3,6 +3,7 @@
from measurement.measures import Weight

from .....core.units import WeightUnits
from .....warehouse import WarehouseClickAndCollectOption
from ....core.enums import WeightUnitsEnum
from ....tests.utils import assert_no_permission, get_graphql_content

Expand Down Expand Up @@ -167,6 +168,41 @@ def test_fetch_variant_no_stocks(
)


def test_fetch_variant_stocks_from_click_and_collect_warehouse(
staff_api_client,
product,
permission_manage_products,
channel_USD,
):
# given
query = QUERY_VARIANT
variant = product.variants.first()
stocks_count = variant.stocks.count()
warehouse = variant.stocks.first().warehouse

# remove the warehouse shipping zones and mark it as click and collect
# the stocks for this warehouse should be still returned
warehouse.shipping_zones.clear()
warehouse.click_and_collect_option = WarehouseClickAndCollectOption.LOCAL_STOCK
warehouse.save(update_fields=["click_and_collect_option"])

variant_id = graphene.Node.to_global_id("ProductVariant", variant.pk)
variables = {"id": variant_id, "countryCode": "EU", "channel": channel_USD.slug}
staff_api_client.user.user_permissions.add(permission_manage_products)

# when
response = staff_api_client.post_graphql(query, variables)

# then
content = get_graphql_content(response)
data = content["data"]["productVariant"]
assert data["name"] == variant.name
assert data["created"] == variant.created_at.isoformat()

assert len(data["stocksByAddress"]) == stocks_count
assert not data["deprecatedStocksByCountry"]


QUERY_PRODUCT_VARIANT_CHANNEL_LISTING = """
query ProductVariantDetails($id: ID!, $channel: String) {
productVariant(id: $id, channel: $channel) {
Expand Down
11 changes: 8 additions & 3 deletions saleor/graphql/product/types/products.py
Expand Up @@ -117,6 +117,7 @@
from ...warehouse.dataloaders import (
AvailableQuantityByProductVariantIdCountryCodeAndChannelSlugLoader,
PreorderQuantityReservedByVariantChannelListingIdLoader,
StocksByProductVariantIdLoader,
StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader,
)
from ...warehouse.types import Stock
Expand Down Expand Up @@ -436,9 +437,13 @@ def resolve_stocks(
):
if address is not None:
country_code = address.country
return StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader(
info.context
).load((root.node.id, country_code, root.channel_slug))
channle_slug = root.channel_slug
if channle_slug or country_code:
return StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( # noqa: E501
info.context
).load((root.node.id, country_code, root.channel_slug))
else:
return StocksByProductVariantIdLoader(info.context).load(root.node.id)

@staticmethod
@load_site_callback
Expand Down
22 changes: 21 additions & 1 deletion saleor/graphql/shipping/dataloaders.py
@@ -1,6 +1,6 @@
from collections import defaultdict

from django.db.models import Exists, F, OuterRef
from django.db.models import Exists, F, OuterRef, Q

from ...channel.models import Channel
from ...shipping.models import (
Expand Down Expand Up @@ -220,3 +220,23 @@ def map_channels(channels):
.load_many({pk for pk, _ in channel_and_zone_is_pairs})
.then(map_channels)
)


class ShippingZonesByCountryLoader(DataLoader):
context_key = "shippingzones_by_country"

def batch_load(self, keys):
lookup = Q()
for key in keys:
lookup |= Q(countries__contains=key)
shipping_zones = ShippingZone.objects.using(
self.database_connection_name
).filter(lookup)

shipping_zones_by_country = defaultdict(list)
for shipping_zone in shipping_zones:
for country_code in keys:
if country_code in shipping_zone.countries:
shipping_zones_by_country[country_code].append(shipping_zone)

return [shipping_zones_by_country[key] for key in keys]