Skip to content

Commit

Permalink
Optimise variant.stocks query.
Browse files Browse the repository at this point in the history
  • Loading branch information
zedzior committed Apr 26, 2024
1 parent 8422c2e commit 8f80dfa
Show file tree
Hide file tree
Showing 7 changed files with 264 additions and 89 deletions.
Expand Up @@ -14,9 +14,7 @@
from ...core.mutations import BaseMutation
from ...core.types import BulkStockError, NonNullList
from ...plugins.dataloaders import get_plugin_manager_promise
from ...warehouse.dataloaders import (
StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader,
)
from ...warehouse.dataloaders import StocksByProductVariantIdLoader
from ...warehouse.types import Warehouse
from ..mutations.product.product_create import StockInput
from ..types import ProductVariant
Expand Down Expand Up @@ -69,9 +67,7 @@ def perform_mutation(cls, _root, info: ResolveInfo, /, **data):
manager.product_variant_back_in_stock, stock, webhooks=webhooks
)

StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader(
info.context
).clear((variant.id, None, None))
StocksByProductVariantIdLoader(info.context).clear(variant.id)

variant = ChannelContext(node=variant, channel_slug=None)
return cls(product_variant=variant)
Expand Down
Expand Up @@ -14,9 +14,7 @@
from ...core.types import NonNullList, StockError
from ...core.validators import validate_one_of_args_is_in_mutation
from ...plugins.dataloaders import get_plugin_manager_promise
from ...warehouse.dataloaders import (
StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader,
)
from ...warehouse.dataloaders import StocksByProductVariantIdLoader
from ...warehouse.types import Warehouse
from ..types import ProductVariant

Expand Down Expand Up @@ -85,9 +83,7 @@ def perform_mutation(cls, _root, info: ResolveInfo, /, **data):

stocks_to_delete.delete()

StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader(
info.context
).clear((variant.id, None, None))
StocksByProductVariantIdLoader(info.context).clear(variant.id)

variant = ChannelContext(node=variant, channel_slug=None)
return cls(product_variant=variant)
Expand Up @@ -15,9 +15,7 @@
from ...core.types import BulkStockError, NonNullList
from ...core.validators import validate_one_of_args_is_in_mutation
from ...plugins.dataloaders import get_plugin_manager_promise
from ...warehouse.dataloaders import (
StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader,
)
from ...warehouse.dataloaders import StocksByProductVariantIdLoader
from ...warehouse.types import Warehouse
from ..mutations.product.product_create import StockInput
from ..types import ProductVariant
Expand Down Expand Up @@ -81,9 +79,7 @@ def perform_mutation(cls, _root, info: ResolveInfo, /, **data):
manager = get_plugin_manager_promise(info.context).get()
cls.update_or_create_variant_stocks(variant, stocks, warehouses, manager)

StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader(
info.context
).clear((variant.id, None, None))
StocksByProductVariantIdLoader(info.context).clear(variant.id)

variant = ChannelContext(node=variant, channel_slug=None)
return cls(product_variant=variant)
Expand Down
36 changes: 36 additions & 0 deletions saleor/graphql/product/tests/queries/test_product_variant_query.py
Expand Up @@ -3,6 +3,7 @@
from measurement.measures import Weight

from .....core.units import WeightUnits
from .....warehouse import WarehouseClickAndCollectOption
from ....core.enums import WeightUnitsEnum
from ....tests.utils import assert_no_permission, get_graphql_content

Expand Down Expand Up @@ -167,6 +168,41 @@ def test_fetch_variant_no_stocks(
)


def test_fetch_variant_stocks_from_click_and_collect_warehouse(
staff_api_client,
product,
permission_manage_products,
channel_USD,
):
# given
query = QUERY_VARIANT
variant = product.variants.first()
stocks_count = variant.stocks.count()
warehouse = variant.stocks.first().warehouse

# remove the warehouse shipping zones and mark it as click and collect
# the stocks for this warehouse should be still returned
warehouse.shipping_zones.clear()
warehouse.click_and_collect_option = WarehouseClickAndCollectOption.LOCAL_STOCK
warehouse.save(update_fields=["click_and_collect_option"])

variant_id = graphene.Node.to_global_id("ProductVariant", variant.pk)
variables = {"id": variant_id, "countryCode": "EU", "channel": channel_USD.slug}
staff_api_client.user.user_permissions.add(permission_manage_products)

# when
response = staff_api_client.post_graphql(query, variables)

# then
content = get_graphql_content(response)
data = content["data"]["productVariant"]
assert data["name"] == variant.name
assert data["created"] == variant.created_at.isoformat()

assert len(data["stocksByAddress"]) == stocks_count
assert not data["deprecatedStocksByCountry"]


QUERY_PRODUCT_VARIANT_CHANNEL_LISTING = """
query ProductVariantDetails($id: ID!, $channel: String) {
productVariant(id: $id, channel: $channel) {
Expand Down
13 changes: 10 additions & 3 deletions saleor/graphql/product/types/products.py
Expand Up @@ -117,6 +117,7 @@
from ...warehouse.dataloaders import (
AvailableQuantityByProductVariantIdCountryCodeAndChannelSlugLoader,
PreorderQuantityReservedByVariantChannelListingIdLoader,
StocksByProductVariantIdLoader,
StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader,
)
from ...warehouse.types import Stock
Expand Down Expand Up @@ -436,9 +437,15 @@ def resolve_stocks(
):
if address is not None:
country_code = address.country
return StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader(
info.context
).load((root.node.id, country_code, root.channel_slug))
channle_slug = root.channel_slug
if channle_slug or country_code:
return StocksWithAvailableQuantityByProductVariantIdCountryCodeAndChannelLoader( # noqa: E501
info.context
).load(
(root.node.id, country_code, root.channel_slug)
)
else:
return StocksByProductVariantIdLoader(info.context).load(root.node.id)

@staticmethod
@load_site_callback
Expand Down
22 changes: 21 additions & 1 deletion saleor/graphql/shipping/dataloaders.py
@@ -1,6 +1,6 @@
from collections import defaultdict

from django.db.models import Exists, F, OuterRef
from django.db.models import Exists, F, OuterRef, Q

from ...channel.models import Channel
from ...shipping.models import (
Expand Down Expand Up @@ -220,3 +220,23 @@ def map_channels(channels):
.load_many({pk for pk, _ in channel_and_zone_is_pairs})
.then(map_channels)
)


class ShippingZonesByCountryLoader(DataLoader):
context_key = "shippingzones_by_country"

def batch_load(self, keys):
lookup = Q()
for key in keys:
lookup |= Q(countries__contains=key)
shipping_zones = ShippingZone.objects.using(
self.database_connection_name
).filter(lookup)

shipping_zones_by_country = defaultdict(list)
for shipping_zone in shipping_zones:
for country_code in keys:
if country_code in shipping_zone.countries:
shipping_zones_by_country[country_code].append(shipping_zone)

return [shipping_zones_by_country[key] for key in keys]

0 comments on commit 8f80dfa

Please sign in to comment.