Skip to content

Commit

Permalink
Adjust tests
Browse files Browse the repository at this point in the history
  • Loading branch information
maarcingebala committed Apr 23, 2024
1 parent 129c9aa commit eac4ea7
Show file tree
Hide file tree
Showing 15 changed files with 228 additions and 159 deletions.
25 changes: 19 additions & 6 deletions saleor/checkout/fetch.py
Expand Up @@ -64,18 +64,15 @@ def get_promotion_discounts(self) -> list["CheckoutLineDiscount"]:

@dataclass
class CheckoutInfo:
# compare=False used to avoid comparing the manager attribute as it's irrelevant
# for the equality checkout infos; this is used in tests
manager: "PluginsManager" = field(compare=False)
lines: Iterable[CheckoutLineInfo] = field(compare=False)

checkout: "Checkout"
user: Optional["User"]
channel: "Channel"
billing_address: Optional["Address"]
shipping_address: Optional["Address"]
tax_configuration: "TaxConfiguration"
shipping_channel_listings: Iterable["ShippingMethodChannelListing"]
lines: Iterable[CheckoutLineInfo]
shipping_channel_listings: list["ShippingMethodChannelListing"]
shipping_method: Optional["ShippingMethod"] = None
collection_point: Optional["Warehouse"] = None
voucher: Optional["Voucher"] = None
Expand All @@ -97,7 +94,7 @@ def all_shipping_methods(self) -> list["ShippingMethodData"]:
initialize_shipping_method_active_status(all_methods, excluded_methods)
return all_methods

@cached_property
@property
def valid_pick_up_points(self) -> Iterable["Warehouse"]:
from .utils import get_valid_collection_points_for_checkout

Expand Down Expand Up @@ -589,3 +586,19 @@ def get_all_shipping_methods_list(
),
)
)


def update_delivery_method_lists_for_checkout_info(
checkout_info: "CheckoutInfo",
shipping_method: Optional["ShippingMethod"],
collection_point: Optional["Warehouse"],
shipping_address: Optional["Address"],
lines: Iterable[CheckoutLineInfo],
shipping_channel_listings: Iterable[ShippingMethodChannelListing],
):
# Update checkout info fields with new data
checkout_info.shipping_method = shipping_method
checkout_info.collection_point = collection_point
checkout_info.shipping_address = shipping_address
checkout_info.lines = lines
checkout_info.shipping_channel_listings = list(shipping_channel_listings)
103 changes: 55 additions & 48 deletions saleor/checkout/tests/test_checkout.py
Expand Up @@ -33,7 +33,6 @@
DeliveryMethodBase,
fetch_checkout_info,
fetch_checkout_lines,
get_delivery_method_info,
)
from ..models import Checkout, CheckoutLine
from ..utils import (
Expand Down Expand Up @@ -317,17 +316,7 @@ def test_get_discount_for_checkout_value_entire_order_voucher(
"saleor.checkout.base_calculations.base_checkout_subtotal",
lambda *args: subtotal,
)
checkout_info = CheckoutInfo(
checkout=checkout,
shipping_address=None,
billing_address=None,
channel=channel_USD,
user=None,
tax_configuration=channel_USD.tax_configuration,
valid_pick_up_points=[],
delivery_method_info=get_delivery_method_info(None, None),
all_shipping_methods=[],
)
manager = get_plugins_manager(allow_replica=False)
lines = [
CheckoutLineInfo(
line=line,
Expand All @@ -342,7 +331,17 @@ def test_get_discount_for_checkout_value_entire_order_voucher(
)
for line in checkout_with_items.lines.all()
]
manager = get_plugins_manager(allow_replica=False)
checkout_info = CheckoutInfo(
checkout=checkout,
shipping_address=None,
billing_address=None,
channel=channel_USD,
user=None,
tax_configuration=channel_USD.tax_configuration,
manager=manager,
lines=lines,
shipping_channel_listings=[],
)

# when
discount = get_voucher_discount_for_checkout(
Expand Down Expand Up @@ -481,17 +480,6 @@ def test_get_discount_for_checkout_value_specific_product_voucher(
"saleor.checkout.base_calculations.base_checkout_subtotal",
lambda *args: subtotal,
)
checkout_info = CheckoutInfo(
checkout=checkout,
shipping_address=None,
billing_address=None,
channel=channel_USD,
user=None,
tax_configuration=channel_USD.tax_configuration,
valid_pick_up_points=[],
delivery_method_info=get_delivery_method_info(None, None),
all_shipping_methods=[],
)
lines = [
CheckoutLineInfo(
line=line,
Expand All @@ -507,6 +495,17 @@ def test_get_discount_for_checkout_value_specific_product_voucher(
for line in checkout_with_items.lines.all()
]
manager = get_plugins_manager(allow_replica=False)
checkout_info = CheckoutInfo(
checkout=checkout,
shipping_address=None,
billing_address=None,
channel=channel_USD,
user=None,
tax_configuration=channel_USD.tax_configuration,
manager=manager,
lines=lines,
shipping_channel_listings=[],
)

# when
discount = get_voucher_discount_for_checkout(
Expand Down Expand Up @@ -596,18 +595,18 @@ def test_get_discount_for_checkout_entire_order_voucher_not_applicable(
"saleor.checkout.base_calculations.base_checkout_subtotal",
lambda *args: subtotal,
)
manager = get_plugins_manager(allow_replica=False)
checkout_info = CheckoutInfo(
checkout=checkout,
delivery_method_info=None,
shipping_address=None,
billing_address=None,
channel=channel_USD,
user=None,
tax_configuration=channel_USD.tax_configuration,
valid_pick_up_points=[],
all_shipping_methods=[],
manager=manager,
lines=[],
shipping_channel_listings=[],
)
manager = get_plugins_manager(allow_replica=False)
with pytest.raises(NotApplicable):
get_voucher_discount_for_checkout(manager, voucher, checkout_info, [], None)

Expand Down Expand Up @@ -777,14 +776,14 @@ def test_get_discount_for_checkout_specific_products_voucher_not_applicable(
checkout = Mock(quantity=total_quantity, spec=Checkout, channel=channel_USD)
checkout_info = CheckoutInfo(
checkout=checkout,
delivery_method_info=get_delivery_method_info(None, None),
shipping_address=None,
billing_address=None,
channel=channel_USD,
user=None,
tax_configuration=channel_USD.tax_configuration,
valid_pick_up_points=[],
all_shipping_methods=[],
manager=manager,
lines=[],
shipping_channel_listings=[],
)
with pytest.raises(NotApplicable):
get_voucher_discount_for_checkout(manager, voucher, checkout_info, [], None)
Expand Down Expand Up @@ -844,7 +843,6 @@ def test_get_discount_for_checkout_shipping_voucher(
monkeypatch,
channel_USD,
shipping_method,
shipping_method_data,
):
manager = get_plugins_manager(allow_replica=False)
tax = Decimal("1.23")
Expand Down Expand Up @@ -883,15 +881,14 @@ def test_get_discount_for_checkout_shipping_voucher(
checkout_info = CheckoutInfo(
checkout=checkout,
shipping_address=shipping_address,
delivery_method_info=get_delivery_method_info(
shipping_method_data, shipping_address
),
billing_address=None,
channel=channel_USD,
user=None,
tax_configuration=channel_USD.tax_configuration,
valid_pick_up_points=[],
all_shipping_methods=[],
manager=manager,
shipping_method=shipping_method,
lines=[],
shipping_channel_listings=shipping_method.channel_listings.all(),
)

discount = get_voucher_discount_for_checkout(
Expand All @@ -901,7 +898,7 @@ def test_get_discount_for_checkout_shipping_voucher(


def test_get_discount_for_checkout_shipping_voucher_all_countries(
monkeypatch, channel_USD, shipping_method, shipping_method_data
monkeypatch, channel_USD, shipping_method
):
subtotal = Money(100, "USD")
monkeypatch.setattr(
Expand Down Expand Up @@ -936,14 +933,15 @@ def test_get_discount_for_checkout_shipping_voucher_all_countries(
manager = get_plugins_manager(allow_replica=False)
checkout_info = CheckoutInfo(
checkout=checkout,
delivery_method_info=get_delivery_method_info(shipping_method_data),
shipping_address=Mock(spec=Address, country=Mock(code="PL")),
billing_address=None,
channel=channel_USD,
user=None,
tax_configuration=channel_USD.tax_configuration,
valid_pick_up_points=[],
all_shipping_methods=[],
shipping_method=shipping_method,
shipping_channel_listings=shipping_method.channel_listings.all(),
manager=manager,
lines=[],
)
discount = get_voucher_discount_for_checkout(
manager, voucher, checkout_info, [], None
Expand Down Expand Up @@ -980,18 +978,18 @@ def test_get_discount_for_checkout_shipping_voucher_limited_countries(
discount=Money(50, channel_USD.currency_code),
)

manager = get_plugins_manager(allow_replica=False)
checkout_info = CheckoutInfo(
checkout=checkout,
delivery_method_info=get_delivery_method_info(None, None),
shipping_address=Mock(spec=Address, country=Mock(code="PL")),
billing_address=None,
channel=channel_USD,
user=None,
tax_configuration=channel_USD.tax_configuration,
valid_pick_up_points=[],
all_shipping_methods=[],
manager=manager,
lines=[],
shipping_channel_listings=[],
)
manager = get_plugins_manager(allow_replica=False)
with pytest.raises(NotApplicable):
get_voucher_discount_for_checkout(
manager,
Expand Down Expand Up @@ -1112,13 +1110,21 @@ def test_get_discount_for_checkout_shipping_voucher_not_applicable(
monkeypatch.setattr(
"saleor.checkout.utils.is_shipping_required", lambda lines: is_shipping_required
)

if shipping_method_data:
shipping_channel_listings = shipping_method.channel_listings.all()
else:
shipping_method = None
shipping_channel_listings = []

checkout = Mock(
is_shipping_required=Mock(return_value=is_shipping_required),
shipping_method=shipping_method,
shipping_address=Mock(spec=Address, country=Mock(code="PL")),
quantity=total_quantity,
spec=Checkout,
channel=channel_USD,
get_value_from_private_metadata=Mock(return_value=None),
)

voucher = Voucher.objects.create(
Expand All @@ -1137,14 +1143,15 @@ def test_get_discount_for_checkout_shipping_voucher_not_applicable(
)
checkout_info = CheckoutInfo(
checkout=checkout,
delivery_method_info=get_delivery_method_info(shipping_method_data),
shipping_address=Mock(spec=Address, country=Mock(code="PL")),
billing_address=None,
channel=channel_USD,
user=None,
tax_configuration=channel_USD.tax_configuration,
valid_pick_up_points=[],
all_shipping_methods=[],
manager=manager,
lines=[],
shipping_method=shipping_method,
shipping_channel_listings=shipping_channel_listings,
)
with pytest.raises(NotApplicable) as e:
get_voucher_discount_for_checkout(
Expand Down
20 changes: 19 additions & 1 deletion saleor/checkout/utils.py
Expand Up @@ -10,6 +10,7 @@
from prices import Money

from ..account.models import User
from ..checkout.fetch import update_delivery_method_lists_for_checkout_info
from ..core.exceptions import ProductNotPublished
from ..core.taxes import zero_taxed_money
from ..core.utils.promo_code import (
Expand Down Expand Up @@ -366,7 +367,14 @@ def change_shipping_address_in_checkout(
if remove and checkout.shipping_address:
checkout.shipping_address.delete()
checkout.shipping_address = address
checkout_info.shipping_address = address
update_delivery_method_lists_for_checkout_info(
checkout_info=checkout_info,
shipping_method=checkout_info.checkout.shipping_method,
collection_point=checkout_info.checkout.collection_point,
shipping_address=address,
lines=lines,
shipping_channel_listings=shipping_channel_listings,
)
updated_fields = ["shipping_address", "last_change"]
return updated_fields

Expand Down Expand Up @@ -850,6 +858,16 @@ def clear_delivery_method(checkout_info: "CheckoutInfo"):
checkout.collection_point = None
checkout.shipping_method = None
checkout_info.shipping_method = None

update_delivery_method_lists_for_checkout_info(
checkout_info=checkout_info,
shipping_method=None,
collection_point=None,
shipping_address=checkout_info.shipping_address,
lines=checkout_info.lines,
shipping_channel_listings=checkout_info.shipping_channel_listings,
)

delete_external_shipping_id(checkout=checkout)
checkout.save(
update_fields=[
Expand Down
10 changes: 10 additions & 0 deletions saleor/graphql/checkout/mutations/checkout_add_promo_code.py
Expand Up @@ -5,6 +5,7 @@
from ....checkout.fetch import (
fetch_checkout_info,
fetch_checkout_lines,
update_delivery_method_lists_for_checkout_info,
)
from ....checkout.utils import add_promo_code_to_checkout, invalidate_checkout_prices
from ....webhook.event_types import WebhookEventAsyncType
Expand Down Expand Up @@ -102,6 +103,15 @@ def perform_mutation( # type: ignore[override]
promo_code,
)

update_delivery_method_lists_for_checkout_info(
checkout_info=checkout_info,
shipping_method=checkout_info.checkout.shipping_method,
collection_point=checkout_info.checkout.collection_point,
shipping_address=checkout_info.shipping_address,
lines=lines,
shipping_channel_listings=shipping_channel_listings,
)

update_checkout_shipping_method_if_invalid(checkout_info, lines)
invalidate_checkout_prices(
checkout_info,
Expand Down
10 changes: 10 additions & 0 deletions saleor/graphql/checkout/mutations/checkout_lines_add.py
Expand Up @@ -4,6 +4,7 @@
from ....checkout.fetch import (
fetch_checkout_info,
fetch_checkout_lines,
update_delivery_method_lists_for_checkout_info,
)
from ....checkout.utils import add_variants_to_checkout, invalidate_checkout_prices
from ....warehouse.reservations import get_reservation_length, is_reservation_enabled
Expand Down Expand Up @@ -162,6 +163,15 @@ def clean_input(
)

lines, _ = fetch_checkout_lines(checkout)
shipping_channel_listings = checkout.channel.shipping_method_listings.all()
update_delivery_method_lists_for_checkout_info(
checkout_info=checkout_info,
shipping_method=checkout_info.checkout.shipping_method,
collection_point=checkout_info.checkout.collection_point,
shipping_address=checkout_info.shipping_address,
lines=lines,
shipping_channel_listings=shipping_channel_listings,
)
return lines

@classmethod
Expand Down

0 comments on commit eac4ea7

Please sign in to comment.