Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release 3.14.0 #419

Merged
merged 4 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 4 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
Releases
========

UNRELEASED
----------
3.14.0 (2024-03-21)
-------------------

* `#415 <https://github.com/pytest-dev/pytest-mock/pull/415>`_: ``MockType`` and ``AsyncMockType`` can be imported from ``pytest_mock`` for type annotation purposes.

* `#420 <https://github.com/pytest-dev/pytest-mock/issues/420>`_: Fixed a regression which would cause ``mocker.patch.object`` to not being properly cleared between tests.


3.13.0 (2024-03-21)
-------------------
Expand Down
62 changes: 24 additions & 38 deletions src/pytest_mock/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import builtins
import functools
import inspect
import sys
import unittest.mock
import warnings
from dataclasses import dataclass
Expand Down Expand Up @@ -30,16 +29,12 @@

_T = TypeVar("_T")

if sys.version_info >= (3, 8):
AsyncMockType = unittest.mock.AsyncMock
MockType = Union[
unittest.mock.MagicMock,
unittest.mock.AsyncMock,
unittest.mock.NonCallableMagicMock,
]
else:
AsyncMockType = Any
MockType = Union[unittest.mock.MagicMock, unittest.mock.NonCallableMagicMock]
AsyncMockType = unittest.mock.AsyncMock
MockType = Union[
unittest.mock.MagicMock,
unittest.mock.AsyncMock,
unittest.mock.NonCallableMagicMock,
]


class PytestMockWarning(UserWarning):
Expand All @@ -54,36 +49,37 @@ class MockCacheItem:

@dataclass
class MockCache:
"""
Cache MagicMock and Patcher instances so we can undo them later.
"""

cache: List[MockCacheItem] = field(default_factory=list)

def find(self, mock: MockType) -> MockCacheItem:
the_mock = next(
(mock_item for mock_item in self.cache if mock_item.mock == mock), None
)
if the_mock is None:
raise ValueError("This mock object is not registered")
return the_mock
def _find(self, mock: MockType) -> MockCacheItem:
for mock_item in self.cache:
if mock_item.mock is mock:
return mock_item
raise ValueError("This mock object is not registered")

def add(self, mock: MockType, **kwargs: Any) -> MockCacheItem:
try:
return self.find(mock)
except ValueError:
self.cache.append(MockCacheItem(mock=mock, **kwargs))
return self.cache[-1]
self.cache.append(MockCacheItem(mock=mock, **kwargs))
return self.cache[-1]

def remove(self, mock: MockType) -> None:
mock_item = self.find(mock)
mock_item = self._find(mock)
if mock_item.patch:
mock_item.patch.stop()
self.cache.remove(mock_item)

def clear(self) -> None:
for mock_item in reversed(self.cache):
if mock_item.patch is not None:
mock_item.patch.stop()
self.cache.clear()

def __iter__(self) -> Iterator[MockCacheItem]:
return iter(self.cache)

def __reversed__(self) -> Iterator[MockCacheItem]:
return reversed(self.cache)


class MockerFixture:
"""
Expand Down Expand Up @@ -154,19 +150,13 @@ def stopall(self) -> None:
Stop all patchers started by this fixture. Can be safely called multiple
times.
"""
for mock_item in reversed(self._mock_cache):
if mock_item.patch is not None:
mock_item.patch.stop()
self._mock_cache.clear()

def stop(self, mock: unittest.mock.MagicMock) -> None:
"""
Stops a previous patch or spy call by passing the ``MagicMock`` object
returned by it.
"""
mock_item = self._mock_cache.find(mock)
if mock_item.patch:
mock_item.patch.stop()
self._mock_cache.remove(mock)

def spy(self, obj: object, name: str) -> MockType:
Expand Down Expand Up @@ -271,17 +261,13 @@ def _start_patch(
# check if `mocked` is actually a mock object, as depending on autospec or target
# parameters `mocked` can be anything
if hasattr(mocked, "__enter__") and warn_on_mock_enter:
if sys.version_info >= (3, 8):
depth = 5
else:
depth = 4
mocked.__enter__.side_effect = lambda: warnings.warn(
"Mocks returned by pytest-mock do not need to be used as context managers. "
"The mocker fixture automatically undoes mocking at the end of a test. "
"This warning can be ignored if it was triggered by mocking a context manager. "
"https://pytest-mock.readthedocs.io/en/latest/remarks.html#usage-as-context-manager",
PytestMockWarning,
stacklevel=depth,
stacklevel=5,
)
return mocked

Expand Down
65 changes: 33 additions & 32 deletions tests/test_pytest_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Generator
from typing import Tuple
from typing import Type
from unittest.mock import AsyncMock
from unittest.mock import MagicMock

import pytest
Expand All @@ -22,14 +23,9 @@
platform.python_implementation() == "PyPy", reason="could not make it work on pypy"
)

# Python 3.8 changed the output formatting (bpo-35500), which has been ported to mock 3.0
NEW_FORMATTING = sys.version_info >= (3, 8)
# Python 3.11.7 changed the output formatting, https://github.com/python/cpython/issues/111019
NEWEST_FORMATTING = sys.version_info >= (3, 11, 7)

if sys.version_info[:2] >= (3, 8):
from unittest.mock import AsyncMock


@pytest.fixture
def needs_assert_rewrite(pytestconfig):
Expand Down Expand Up @@ -173,12 +169,7 @@ def test_mock_patch_dict_resetall(mocker: MockerFixture) -> None:
"NonCallableMock",
"PropertyMock",
"sentinel",
pytest.param(
"seal",
marks=pytest.mark.skipif(
sys.version_info < (3, 7), reason="seal is present on 3.7 and above"
),
),
"seal",
],
)
def test_mocker_aliases(name: str, pytestconfig: Any) -> None:
Expand Down Expand Up @@ -243,10 +234,8 @@ def __test_failure_message(self, mocker: MockerFixture, **kwargs: Any) -> None:
expected_name = kwargs.get("name") or "mock"
if NEWEST_FORMATTING:
msg = "expected call not found.\nExpected: {0}()\n Actual: not called."
elif NEW_FORMATTING:
msg = "expected call not found.\nExpected: {0}()\nActual: not called."
else:
msg = "Expected call: {0}()\nNot called"
msg = "expected call not found.\nExpected: {0}()\nActual: not called."
expected_message = msg.format(expected_name)
stub = mocker.stub(**kwargs)
with pytest.raises(AssertionError, match=re.escape(expected_message)):
Expand All @@ -259,10 +248,6 @@ def test_failure_message_with_no_name(self, mocker: MagicMock) -> None:
def test_failure_message_with_name(self, mocker: MagicMock, name: str) -> None:
self.__test_failure_message(mocker, name=name)

@pytest.mark.skipif(
sys.version_info[:2] < (3, 8),
reason="This Python version doesn't have `AsyncMock`.",
)
def test_async_stub_type(self, mocker: MockerFixture) -> None:
assert isinstance(mocker.async_stub(), AsyncMock)

Expand Down Expand Up @@ -892,17 +877,11 @@ def test(mocker):
"""
)
result = testdir.runpytest("-s")
if NEW_FORMATTING:
expected_lines = [
"*AssertionError: expected call not found.",
"*Expected: mock('', bar=4)",
"*Actual: mock('fo')",
]
else:
expected_lines = [
"*AssertionError: Expected call: mock('', bar=4)*",
"*Actual call: mock('fo')*",
]
expected_lines = [
"*AssertionError: expected call not found.",
"*Expected: mock('', bar=4)",
"*Actual: mock('fo')",
]
expected_lines += [
"*pytest introspection follows:*",
"*Args:",
Expand All @@ -918,9 +897,6 @@ def test(mocker):
result.stdout.fnmatch_lines(expected_lines)


@pytest.mark.skipif(
sys.version_info < (3, 8), reason="AsyncMock is present on 3.8 and above"
)
@pytest.mark.usefixtures("needs_assert_rewrite")
def test_detailed_introspection_async(testdir: Any) -> None:
"""Check that the "mock_use_standalone" is being used."""
Expand Down Expand Up @@ -1288,3 +1264,28 @@ def foo(self):
mocker.stop(spy)
assert un_spy.foo() == 42
assert spy.call_count == 1


def test_stop_multiple_patches(mocker: MockerFixture) -> None:
"""Regression for #420."""

class Class1:
@staticmethod
def get():
return 1

class Class2:
@staticmethod
def get():
return 2

def handle_get():
return 3

mocker.patch.object(Class1, "get", handle_get)
mocker.patch.object(Class2, "get", handle_get)

mocker.stopall()

assert Class1.get() == 1
assert Class2.get() == 2