forked from facebook/TestSlide
/
mock_callable.py
1271 lines (1097 loc) · 43.7 KB
/
mock_callable.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import asyncio
import functools
import inspect
import platform
import re
from inspect import Traceback
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
Union,
)
from unittest.mock import Mock
from testslide.core.lib import _validate_return_type, _wrap_signature_and_type_validation
from testslide.core.strict_mock import StrictMock
from .lib import CoroutineValueError, _bail_if_private, _is_a_builtin
from .patch import _is_instance_method, _patch
if TYPE_CHECKING:
from .matchers import RegexMatches # noqa: F401
from .mock_constructor import _MockConstructorDSL # noqa: F401
def mock_callable(
    target: Any,
    method: str,
    allow_private: bool = False,
    # type_validation accepted values:
    # * None: type validation will be enabled except if target is a StrictMock
    # with disabled type validation
    # * True: type validation will be enabled (regardless of target type)
    # * False: type validation will be disabled
    type_validation: Optional[bool] = None,
) -> "_MockCallableDSL":
    """
    Begin a mock_callable() DSL chain for ``method`` of ``target``.

    Captures frame information (without source context) for the calling code
    and passes it, together with the remaining arguments, to
    ``_MockCallableDSL``.

    Args:
        target: object owning the callable to mock; ``_MockCallableDSL`` also
            accepts an import-path string, which it resolves itself.
        method: name of the callable attribute to mock.
        allow_private: forwarded to ``_MockCallableDSL``, which checks it
            with ``_bail_if_private`` before patching.
        type_validation: see the accepted values listed above.

    Returns:
        The ``_MockCallableDSL`` used to define behaviors and assertions.
    """
    # NOTE(review): this walks two frames back (``f_back.f_back``) while
    # mock_async_callable() walks only one. Presumably this function is reached
    # through an extra wrapper frame; confirm the intended depth for each
    # entry point. If the stack is shallower than expected, this line or
    # getframeinfo() raises instead of degrading gracefully.
    caller_frame = inspect.currentframe().f_back.f_back  # type: ignore
    # loading the context ends up reading files from disk and that might block
    # the event loop, so we don't do it.
    caller_frame_info = inspect.getframeinfo(caller_frame, context=0)  # type: ignore
    return _MockCallableDSL(
        target,
        method,
        caller_frame_info,
        allow_private=allow_private,
        type_validation=type_validation,
    )
def mock_async_callable(
    target: Union[type, str],
    method: str,
    callable_returns_coroutine: bool = False,
    allow_private: bool = False,
    type_validation: bool = True,
) -> "_MockAsyncCallableDSL":
    """
    Begin a mock_async_callable() DSL chain for ``method`` of ``target``.

    Captures frame information (without source context) for the direct
    caller and forwards it, with the remaining arguments passed positionally,
    to ``_MockAsyncCallableDSL``.

    Args:
        target: object (or import-path string) owning the callable to mock.
        method: name of the callable attribute to mock.
        callable_returns_coroutine: set when the mocked callable is not a
            coroutine function itself but returns a coroutine.
        allow_private: permit mocking private attribute names.
        type_validation: whether returned values are type-validated.

    Returns:
        The ``_MockAsyncCallableDSL`` used to define behaviors and assertions.
    """
    # NOTE(review): walks one frame back, whereas mock_callable() walks two;
    # confirm both entry points report the intended caller frame.
    caller_frame = inspect.currentframe().f_back  # type: ignore
    # loading the context ends up reading files from disk and that might block
    # the event loop, so we don't do it.
    caller_frame_info = inspect.getframeinfo(caller_frame, context=0)  # type: ignore
    return _MockAsyncCallableDSL(
        target,
        method,
        caller_frame_info,
        callable_returns_coroutine,
        allow_private,
        type_validation,
    )
# Zero-argument callables that undo patches and registry entries created while
# mocking; unpatch_all_callable_mocks() calls each of them and then empties
# this list.
_unpatchers: List[Callable] = []  # noqa T484
def _default_register_assertion(assertion: Callable) -> None:
"""
This method must be redefined by the test framework using mock_callable().
It will be called when a new assertion is defined, passing a callable as an
argument that evaluates that assertion. Every defined assertion during a test
must be called after the test code ends, and before the test finishes.
"""
raise NotImplementedError("This method must be redefined by the test framework")
# Hook used to register assertions. The test framework is expected to replace
# it; _is_setup() checks whether that happened by identity, and
# unpatch_all_callable_mocks() puts the default back.
register_assertion = _default_register_assertion
# Whether the single shared call-order assertion has been registered yet.
_call_order_assertion_registered: bool = False
# (target, method, runner) for every call received by a runner that carries an
# order assertion, in the order the calls happened.
_received_ordered_calls: List[Tuple[Any, str, "_BaseRunner"]] = []
# (target, method, runner) in the order the order assertions were defined.
_expected_ordered_calls: List[Tuple[Any, str, "_BaseRunner"]] = []
def unpatch_all_callable_mocks() -> None:
    """
    Undo every active mock_callable() patch and reset the module state.

    Call this unconditionally after each test. The assertion hook and the
    call-order bookkeeping are reset first, then every registered unpatcher is
    run. A failing unpatcher does not stop the others; all failures are
    collected and reported together as one RuntimeError after the list of
    unpatchers has been emptied.
    """
    # Only these two names are rebound; the lists below are mutated in place.
    global register_assertion, _call_order_assertion_registered

    register_assertion = _default_register_assertion
    _call_order_assertion_registered = False
    _received_ordered_calls.clear()
    _expected_ordered_calls.clear()

    errors = []
    for undo in _unpatchers:
        try:
            undo()
        except Exception as exc:
            errors.append(exc)
    _unpatchers.clear()

    if errors:
        raise RuntimeError(
            "Exceptions raised when unpatching: {}".format(errors)
        )
def _is_setup() -> bool:
    """Tell whether a test framework has installed its own assertion hook."""
    # Reading module globals needs no ``global`` declaration.
    return not (register_assertion is _default_register_assertion)
def _format_target(target: Union[str, type]) -> str:
if hasattr(target, "__repr__"):
return repr(target)
else:
return "{}.{} instance with id {}".format(
target.__module__, type(target).__name__, id(target)
)
def _format_args(indent: int, *args: Any, **kwargs: Any) -> str:
indentation = " " * indent
s = ""
if args:
s += ("{}{}\n").format(indentation, args)
if kwargs:
s += indentation + "{"
if kwargs:
s += "\n"
for k in sorted(kwargs.keys()):
s += "{} {}={},\n".format(indentation, k, repr(kwargs[k]))
s += "{}".format(indentation)
s += "}\n"
return s
def _is_coroutine(obj: Any) -> bool:
if [int(re.sub(r"[^0-9]", "", x)) for x in platform.python_version_tuple()] < [
3,
11,
]:
return inspect.iscoroutine(obj) or isinstance(obj, asyncio.coroutines.CoroWrapper) # type: ignore
else:
return inspect.iscoroutine(obj)
def _is_coroutinefunction(func: Any) -> bool:
    """
    Report whether ``func`` should be treated as a coroutine function.

    Anything ``_is_a_builtin`` reports as a builtin is accepted as well,
    because builtins cannot be reliably introspected for this
    (https://bugs.python.org/issue38225).
    """
    # asyncio's check is preferred over inspect's: upcoming Cython releases
    # make asyncio.iscoroutinefunction return True where
    # inspect.iscoroutinefunction would return False.
    if asyncio.iscoroutinefunction(func):
        return True
    return _is_a_builtin(func)
##
## Exceptions
##
class UndefinedBehaviorForCall(BaseException):
    """
    Signals that a mocked callable was called but has no behavior for it.

    Derives from BaseException rather than Exception so that broad
    ``except Exception`` handlers in the code under test cannot swallow it.
    """
class UnexpectedCallReceived(BaseException):
    """
    Signals a call that the mock was configured to refuse.

    Derives from BaseException rather than Exception so that broad
    ``except Exception`` handlers in the code under test cannot swallow it.
    """
class UnexpectedCallArguments(BaseException):
    """
    Signals a call whose arguments match none of the ones the mock expects.

    Derives from BaseException rather than Exception so that broad
    ``except Exception`` handlers in the code under test cannot swallow it.
    """
class NotACoroutine(BaseException):
    """
    Signals that a mock which must produce a coroutine was given something else.

    Derives from BaseException rather than Exception so that broad
    ``except Exception`` handlers in the code under test cannot swallow it.
    """
##
## Runners
##
class _BaseRunner:
    """
    One behavior defined for a mocked callable, plus its call bookkeeping.

    A runner may be restricted to specific call arguments
    (``add_accepted_args`` / ``can_accept_args``), counts the calls it
    receives, can reject calls beyond a maximum at call time, and registers
    call-count and call-order assertions through the module-level
    ``register_assertion`` hook. Subclasses provide ``run``.
    """

    # Consulted by _CallableMock._validate_return_type: when False, values
    # produced by this runner are not type-validated.
    TYPE_VALIDATION = True

    def __init__(
        self, target: Any, method: str, original_callable: Union[Callable, Mock]
    ) -> None:
        self.target = target
        self.method = method
        self.original_callable = original_callable
        # (args, kwargs) this runner is restricted to; None means any call is
        # accepted (see can_accept_args).
        self.accepted_args: Optional[Tuple[Any, Any]] = None
        self._call_count: int = 0
        # Maximum number of calls checked by inc_call_count(); None = no limit.
        self._max_calls: Optional[int] = None
        # Set by add_call_order_assertion(); makes register_call() log calls.
        self._has_order_assertion = False
        # When True, accepted_args only needs to be contained in the received
        # arguments instead of matching them exactly.
        self._accept_partial_call = False

    def register_call(self, *args: Any, **kwargs: Any) -> None:
        """
        Record one received call.

        The call is appended to the module-level ordered-call log when this
        runner carries an order assertion. Then the call count is incremented,
        which may raise UnexpectedCallReceived.
        """
        global _received_ordered_calls
        if self._has_order_assertion:
            _received_ordered_calls.append((self.target, self.method, self))
        self.inc_call_count()

    @property
    def call_count(self) -> int:
        """Number of calls received so far."""
        return self._call_count

    @property
    def max_calls(self) -> Optional[int]:
        """Maximum number of calls accepted at call time, or None."""
        return self._max_calls

    def _set_max_calls(self, times: int) -> None:
        # Keep the lowest limit requested so far.
        # NOTE(review): because this tests truthiness, a stored limit of 0 is
        # treated like "no limit" and is replaced by any later limit; `is None`
        # may have been intended. Confirm.
        if not self._max_calls or times < self._max_calls:
            self._max_calls = times

    def inc_call_count(self) -> None:
        """
        Count one call.

        Raises:
            UnexpectedCallReceived: if a non-zero maximum is set and this call
                exceeds it.
        """
        self._call_count += 1
        # NOTE(review): a max_calls of 0 is falsy and is therefore never
        # enforced here; such calls can only be caught by a registered
        # assertion.
        if self.max_calls and self._call_count > self.max_calls:
            raise UnexpectedCallReceived(
                (
                    "Unexpected call received.\n"
                    "{}, {}:\n"
                    " expected to receive at most {} calls with {}"
                    " but received an extra call."
                ).format(
                    _format_target(self.target),
                    repr(self.method),
                    self.max_calls,
                    self._args_message(),
                )
            )

    def add_accepted_args(
        self,
        _accept_partial_call: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Restrict this runner to calls with the given arguments.

        ``_accept_partial_call`` selects subset matching instead of exact
        matching (see can_accept_args). Calling this again replaces any
        previously accepted arguments.
        """
        self.accepted_args = (args, kwargs)
        self._accept_partial_call = _accept_partial_call

    def can_accept_args(self, *args: Any, **kwargs: Any) -> bool:
        """
        Return whether a call with ``args``/``kwargs`` is handled by this runner.

        Without accepted_args every call is accepted. With exact matching,
        the positional and keyword arguments must equal the accepted ones.
        With partial matching, every accepted positional value must equal some
        received positional value (its position is not checked). Every
        accepted keyword must also be present with an equal value. Extra
        received arguments are allowed.
        """
        if self.accepted_args:
            if self._accept_partial_call:
                args_match = all(
                    any(elem == arg for arg in args) for elem in self.accepted_args[0]
                )
                kwargs_match = all(
                    elem in kwargs.keys()
                    and kwargs[elem] == self.accepted_args[1][elem]
                    for elem in self.accepted_args[1].keys()
                )
                return args_match and kwargs_match
            else:
                return self.accepted_args == (args, kwargs)
        else:
            return True

    def _args_message(self) -> str:
        """Describe the accepted arguments for use in error messages."""
        if self.accepted_args:
            return "arguments:\n{}".format(
                _format_args(2, *self.accepted_args[0], **self.accepted_args[1])
            )
        else:
            return "any arguments "

    def add_exact_calls_assertion(self, times: int) -> None:
        """
        Cap calls at ``times`` and register an assertion that this runner was
        called exactly ``times`` times.
        """
        self._set_max_calls(times)

        def assertion() -> None:
            if times != self.call_count:
                raise AssertionError(
                    (
                        "calls did not match assertion.\n"
                        "{}, {}:\n"
                        " expected: called exactly {} time(s) with {}"
                        " received: {} call(s)"
                    ).format(
                        _format_target(self.target),
                        repr(self.method),
                        times,
                        self._args_message(),
                        self.call_count,
                    )
                )

        register_assertion(assertion)

    def add_at_least_calls_assertion(self, times: int) -> None:
        """
        Register an assertion that this runner was called at least ``times``
        times. No call-time limit is set.
        """

        def assertion() -> None:
            if self.call_count < times:
                raise AssertionError(
                    (
                        "calls did not match assertion.\n"
                        "{}, {}:\n"
                        " expected: called at least {} time(s) with {}"
                        " received: {} call(s)"
                    ).format(
                        _format_target(self.target),
                        repr(self.method),
                        times,
                        self._args_message(),
                        self.call_count,
                    )
                )

        register_assertion(assertion)

    def add_at_most_calls_assertion(self, times: int) -> None:
        """
        Cap calls at ``times`` and register an assertion that this runner was
        called at most ``times`` times.
        """
        self._set_max_calls(times)

        def assertion() -> None:
            # NOTE(review): `not self.call_count` also fails the assertion when
            # the runner was never called, so "at most" effectively means
            # "between 1 and times". Confirm this is intended.
            if not self.call_count or self.call_count > times:
                raise AssertionError(
                    (
                        "calls did not match assertion.\n"
                        "{}, {}:\n"
                        " expected: called at most {} time(s) with {}"
                        " received: {} call(s)"
                    ).format(
                        _format_target(self.target),
                        repr(self.method),
                        times,
                        self._args_message(),
                        self.call_count,
                    )
                )

        register_assertion(assertion)

    def add_call_order_assertion(self) -> None:
        """
        Include this runner in the shared call-order assertion.

        The first use registers one assertion comparing the received ordered
        calls with the expected ones. Later uses only append to the expected
        list, until unpatch_all_callable_mocks() resets the flag.
        """
        global _call_order_assertion_registered, _received_ordered_calls, _expected_ordered_calls

        if not _call_order_assertion_registered:

            def assertion() -> None:
                if _received_ordered_calls != _expected_ordered_calls:
                    raise AssertionError(
                        (
                            "calls did not match assertion.\n"
                            "\n"
                            "These calls were expected to have happened in order:\n"
                            "\n"
                            "{}\n"
                            "\n"
                            "but these calls were made:\n"
                            "\n"
                            "{}"
                        ).format(
                            "\n".join(
                                (
                                    " {}, {} with {}".format(
                                        _format_target(target),
                                        repr(method),
                                        runner._args_message().rstrip(),
                                    )
                                    for target, method, runner in _expected_ordered_calls
                                )
                            ),
                            "\n".join(
                                (
                                    " {}, {} with {}".format(
                                        _format_target(target),
                                        repr(method),
                                        runner._args_message().rstrip(),
                                    )
                                    for target, method, runner in _received_ordered_calls
                                )
                            ),
                        )
                    )

            register_assertion(assertion)
            _call_order_assertion_registered = True

        _expected_ordered_calls.append((self.target, self.method, self))
        # From now on register_call() logs this runner's calls.
        self._has_order_assertion = True
class _Runner(_BaseRunner):
    """
    Base for runners whose ``run`` is a plain synchronous method.

    ``run`` here only records the call; subclasses extend it to produce the
    mocked result.
    """

    def run(self, *args: Any, **kwargs: Any) -> None:
        # Call bookkeeping: order log, call count and max-calls check.
        return super().register_call(*args, **kwargs)
class _AsyncRunner(_BaseRunner):
    """
    Base for runners whose ``run`` is a coroutine function.

    Awaiting ``run`` only records the call; subclasses extend it to produce
    the mocked result.
    """

    async def run(self, *args: Any, **kwargs: Any) -> None:
        # Call bookkeeping: order log, call count and max-calls check.
        return super().register_call(*args, **kwargs)
class _ReturnValueRunner(_Runner):
    """
    Runner that returns one fixed value for every call.

    Raises CoroutineValueError at construction when the value is a coroutine,
    unless ``allow_coro`` is set.
    """

    def __init__(
        self,
        target: Any,
        method: str,
        original_callable: Union[Callable, Mock],
        value: Optional[Any],
        allow_coro: bool = False,
    ) -> None:
        super().__init__(target, method, original_callable)
        if not allow_coro:
            if _is_coroutine(value):
                raise CoroutineValueError()
        self.return_value = value

    def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        super().run(*args, **kwargs)
        fixed_value = self.return_value
        return fixed_value
class _ReturnValuesRunner(_Runner):
    """
    Runner that hands out the given values, one per call, in order.

    Calls made after every value has been returned raise
    UndefinedBehaviorForCall. Raises CoroutineValueError at construction when
    any value is a coroutine, unless ``allow_coro`` is set.
    """

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        values_list: List[Any],
        allow_coro: bool = False,
    ) -> None:
        super().__init__(target, method, original_callable)
        if not allow_coro:
            if any(map(_is_coroutine, values_list)):
                raise CoroutineValueError()
        # Stored back-to-front so each call can cheaply pop() from the end.
        self.values_list = list(reversed(values_list))

    def run(self, *args: Any, **kwargs: Any) -> Any:
        super().run(*args, **kwargs)
        if not self.values_list:
            raise UndefinedBehaviorForCall("No more values to return!")
        return self.values_list.pop()
class _YieldValuesRunner(_Runner):
    """
    Runner whose call result is an iterator over the configured values.

    The runner is itself the iterator and ``run`` returns ``self``, so the
    read position (``index``) is shared by all calls. Once it is exhausted,
    later calls get an iterator that yields nothing. Return-type validation
    is switched off for this runner.
    """

    TYPE_VALIDATION = False

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        values_list: List[Any],
        allow_coro: bool = False,
    ) -> None:
        super().__init__(target, method, original_callable)
        self.values_list = values_list
        self.index = 0
        if not allow_coro:
            if any(map(_is_coroutine, values_list)):
                raise CoroutineValueError()

    def __iter__(self) -> "_YieldValuesRunner":
        return self

    def __next__(self) -> Any:
        position = self.index
        try:
            item = self.values_list[position]
        except IndexError:
            raise StopIteration()
        self.index = position + 1
        return item

    def run(self, *args: Any, **kwargs: Any) -> "_YieldValuesRunner":  # type: ignore
        super().run(*args, **kwargs)
        iterator = self
        return iterator
class _RaiseRunner(_Runner):
    """Runner that records each call and then raises the configured exception."""

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        exception: BaseException,
    ) -> None:
        super().__init__(target, method, original_callable)
        self.exception = exception

    def run(self, *args: Any, **kwargs: Any) -> None:
        super().run(*args, **kwargs)
        error = self.exception
        raise error
class _ImplementationRunner(_Runner):
    """
    Runner that delegates each call to a user-supplied implementation.

    The implementation's result is returned as-is, except that a coroutine
    result raises CoroutineValueError unless ``allow_coro`` was set.
    """

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        new_implementation: Callable,
        allow_coro: bool = False,
    ) -> None:
        super().__init__(target, method, original_callable)
        self.new_implementation = new_implementation
        self._allow_coro = allow_coro

    def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        super().run(*args, **kwargs)
        result = self.new_implementation(*args, **kwargs)
        if self._allow_coro:
            return result
        if _is_coroutine(result):
            raise CoroutineValueError()
        return result
class _AsyncImplementationRunner(_AsyncRunner):
    """
    Async runner that awaits the coroutine produced by a user implementation.

    Raises NotACoroutine when the implementation does not return a coroutine.
    """

    def __init__(
        self,
        target: Union[type, str],
        method: str,
        original_callable: Union[Callable[..., Any], Mock],
        new_implementation: Callable,
    ) -> None:
        super().__init__(target, method, original_callable)
        self.new_implementation = new_implementation

    async def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        await super().run(*args, **kwargs)
        pending = self.new_implementation(*args, **kwargs)
        if _is_coroutine(pending):
            return await pending
        raise NotACoroutine(
            f"Function did not return a coroutine.\n"
            f"{self.new_implementation} must return a coroutine."
        )
class _CallOriginalRunner(_Runner):
    """Runner that records each call and forwards it to the original callable."""

    def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        super().run(*args, **kwargs)
        original = self.original_callable
        return original(*args, **kwargs)
class _AsyncCallOriginalRunner(_AsyncRunner):
    """Async runner that records each call and awaits the original callable."""

    async def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        await super().run(*args, **kwargs)
        awaitable = self.original_callable(*args, **kwargs)
        return await awaitable
##
## Callable Mocks
##
class _CallableMock:
    """
    Callable that replaces a patched method and dispatches calls to runners.

    Each call is handled by the first runner in ``self.runners`` that accepts
    the call's arguments. _MockCallableDSL._add_runner inserts new runners at
    the front, so the most recently defined matching behavior wins. When
    ``is_async`` is set, the call returns a coroutine that runs the runner
    only once awaited. Returned values may be type-validated.
    """

    def __init__(
        self,
        target: Any,
        method: str,
        caller_frame_info: Traceback,
        is_async: bool = False,
        callable_returns_coroutine: bool = False,
        # type_validation accepted values:
        # * None: type validation will be enabled except if target is a StrictMock
        # with disabled type validation
        # * True: type validation will be enabled (regardless of target type)
        # * False: type validation will be disabled
        type_validation: Optional[bool] = None,
    ) -> None:
        self.target = target
        self.method = method
        # Tried in order by _get_runner(); the first accepting runner wins.
        self.runners: List[_BaseRunner] = []
        self.is_async = is_async
        self.callable_returns_coroutine = callable_returns_coroutine
        # None and True both enable validation; only an explicit False
        # disables it.
        self.type_validation = type_validation or type_validation is None
        # Frame info of the code that created the mock; passed to the
        # module-level _validate_return_type() helper.
        self.caller_frame_info = caller_frame_info
        if type_validation is None and isinstance(target, StrictMock):
            # The caller did not choose explicitly, so inherit the StrictMock's
            # own setting: if the StrictMock has type validation disabled,
            # validation is disabled here too.
            self.type_validation = target._type_validation

    def _get_runner(self, *args: Any, **kwargs: Any) -> Any:
        """Return the first runner accepting these arguments, or None."""
        for runner in self.runners:
            if runner.can_accept_args(*args, **kwargs):
                return runner
        return None

    def _validate_return_type(self, runner: _BaseRunner, value: Any) -> None:
        """
        Type-check ``value`` when validation is on for this mock and the runner.

        Validates against the runner's original callable when there is one.
        Otherwise, for StrictMock targets, it validates against the target's
        attribute.
        """
        if self.type_validation and runner.TYPE_VALIDATION:
            if runner.original_callable is not None:
                _validate_return_type(
                    runner.original_callable,
                    value,
                    self.caller_frame_info,
                    self.callable_returns_coroutine,
                )
            elif isinstance(runner.target, StrictMock):
                # NOTE(review): unlike the branch above, callable_returns_coroutine
                # is not forwarded here. Confirm whether that is intentional.
                _validate_return_type(
                    getattr(runner.target, runner.method), value, self.caller_frame_info
                )

    def __call__(self, *args: Any, **kwargs: Any) -> Optional[Any]:
        """
        Dispatch a call to the matching runner and return its result.

        For async mocks the result is a coroutine that runs the runner (and
        counts the call) only when awaited.

        Raises:
            UnexpectedCallArguments: no runner accepts the arguments, but at
                least one runner was restricted to specific arguments.
            UndefinedBehaviorForCall: no runner accepts the arguments and none
                has registered arguments.
        """
        runner = self._get_runner(*args, **kwargs)
        if runner:
            if self.is_async:
                if isinstance(runner, _AsyncRunner):
                    # The runner's run() is a coroutine function: await it.
                    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                        value = await runner.run(*args, **kwargs)
                        self._validate_return_type(runner, value)
                        return value

                    value = async_wrapper(*args, **kwargs)
                else:
                    # Synchronous runner on an async mock: run it inside a
                    # coroutine so the caller still receives an awaitable.
                    async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
                        value = runner.run(*args, **kwargs)
                        self._validate_return_type(runner, value)
                        return value

                    value = async_wrapper(*args, **kwargs)
            else:
                value = runner.run(*args, **kwargs)
                self._validate_return_type(runner, value)
            return value
        else:
            ex_msg = (
                "{}, {}:\n"
                " Received call:\n"
                "{}"
                " But no behavior was defined for it."
            ).format(
                _format_target(self.target),
                repr(self.method),
                _format_args(2, *args, **kwargs),
            )
            if self._registered_calls:
                ex_msg += "\n These are the registered calls:\n" "{}".format(
                    "".join(
                        _format_args(2, *reg_args, **reg_kwargs)
                        for reg_args, reg_kwargs in self._registered_calls
                    )
                )
                raise UnexpectedCallArguments(ex_msg)
            raise UndefinedBehaviorForCall(ex_msg)

    @property
    def _registered_calls(self) -> Any:
        """The accepted (args, kwargs) of every runner that has them, in order."""
        return [runner.accepted_args for runner in self.runners if runner.accepted_args]
##
## Support
##
class _MockCallableDSL:
CALLABLE_MOCKS: Dict[
Union[int, Tuple[int, str]], Union[Callable[[Type[object]], Any]]
] = {}
_NAME: str = "mock_callable"
def _validate_patch(
self,
name: str = "mock_callable",
other_name: str = "mock_async_callable",
coroutine_function: bool = False,
callable_returns_coroutine: bool = False,
) -> None:
if self._method == "__new__":
raise ValueError(
f"Mocking __new__ is not allowed with {name}(), please use "
"mock_constructor()."
)
_bail_if_private(self._method, self.allow_private)
if isinstance(self._target, StrictMock):
template_value = getattr(self._target._template, self._method, None)
if template_value and callable(template_value):
if not coroutine_function and asyncio.iscoroutinefunction(
template_value
):
raise ValueError(
f"{name}() can not be used with coroutine functions.\n"
f"The attribute '{self._method}' of the template class "
f"of {self._target} is a coroutine function. You can "
f"use {other_name}() instead."
)
if coroutine_function and not (
_is_coroutinefunction(template_value) or callable_returns_coroutine
):
raise ValueError(
f"{name}() can not be used with non coroutine "
"functions.\n"
f"The attribute '{self._method}' of the template class "
f"of {self._target} is not a coroutine function. You "
f"can use {other_name}() instead."
)
else:
if inspect.isclass(self._target) and _is_instance_method(
self._target, self._method
):
raise ValueError(
"Patching an instance method at the class is not supported: "
"bugs are easy to introduce, as patch is not scoped for an "
"instance, which can potentially even break class behavior; "
"assertions on calls are ambiguous (for every instance or one "
"global assertion?)."
)
original_callable = getattr(self._target, self._method)
if not callable(original_callable):
raise ValueError(
f"{name}() can only be used with callable attributes and "
f"{repr(original_callable)} is not."
)
if inspect.isclass(original_callable):
raise ValueError(
f"{name}() can not be used with classes: "
f"{repr(original_callable)}. Perhaps you want to use "
"mock_constructor() instead."
)
if not coroutine_function and asyncio.iscoroutinefunction(
original_callable
):
raise ValueError(
f"{name}() can not be used with coroutine functions.\n"
f"{original_callable} is a coroutine function. You can use "
f"{other_name}() instead."
)
if coroutine_function and not (
_is_coroutinefunction(original_callable) or callable_returns_coroutine
):
raise ValueError(
f"{name}() can not be used with non coroutine functions.\n"
f"{original_callable} is not a coroutine function. You can "
f"use {other_name}() instead."
)
def _patch(
self, new_value: Union[Callable, _CallableMock]
) -> Union[Tuple[Callable, Callable], Tuple[Mock, Callable], Tuple[None, Callable]]:
self._validate_patch()
if isinstance(self._target, StrictMock):
original_callable = None
else:
original_callable = getattr(self._target, self._method)
new_value = _wrap_signature_and_type_validation(
new_value,
self._target,
self._method,
self.type_validation or self.type_validation is None,
)
if isinstance(self._target, (Mock, StrictMock)) or not hasattr(
self._target, "__slots__"
):
restore = self._method in self._target.__dict__
restore_value = self._target.__dict__.get(self._method, None)
else:
restore = self._method in self._target.__slots__
restore_value = getattr(self._target, self._method)
if inspect.isclass(self._target):
new_value = staticmethod(new_value) # type: ignore
unpatcher = _patch(
self._target, self._method, new_value, restore, restore_value
)
return original_callable, unpatcher
def _get_callable_mock(self) -> _CallableMock:
return _CallableMock(
self._original_target,
self._method,
self.caller_frame_info,
type_validation=self.type_validation,
)
def __init__(
self,
target: Any,
method: str,
caller_frame_info: Traceback,
callable_mock: Union[Callable[[Type[object]], Any], _CallableMock, None] = None,
original_callable: Optional[Callable] = None,
allow_private: bool = False,
type_validation: Optional[bool] = None,
) -> None:
if not _is_setup():
raise RuntimeError(
"TestSlide was not correctly setup before usage!\n"
"This error happens when mock_callable, mock_async_callable or "
"mock_constructor are attempted to be used without correct "
"test framework integration, meaning unpatching and "
"assertions will not work as expected.\n"
"A common scenario for this is when testslide.TestCase is "
"subclassed with setUp() overridden but super().setUp() was not "
"called."
)
self._original_target = target
self._method = method
self._runner: Optional[_BaseRunner] = None
self._next_runner_accepted_args: Any = None
self.allow_private = allow_private
self.type_validation = type_validation
self.caller_frame_info = caller_frame_info
self._allow_coro = False
self._accept_partial_call = False
if isinstance(target, str):
from testslide import _importer
self._target = _importer(target)
else:
self._target = target
target_method_id = (id(self._target), method)
if target_method_id not in self.CALLABLE_MOCKS:
if not callable_mock:
patch = True
callable_mock = self._get_callable_mock()
else:
patch = False
self.CALLABLE_MOCKS[target_method_id] = callable_mock
self._callable_mock = callable_mock
def del_callable_mock() -> None:
del self.CALLABLE_MOCKS[target_method_id]
_unpatchers.append(del_callable_mock)
if patch:
original_callable, unpatcher = self._patch(callable_mock)
_unpatchers.append(unpatcher)
self._original_callable = original_callable
callable_mock.original_callable = original_callable # type: ignore
else:
self._callable_mock = self.CALLABLE_MOCKS[target_method_id]
self._original_callable = self._callable_mock.original_callable # type: ignore
def _add_runner(self, runner: _BaseRunner) -> None:
if self._runner:
raise ValueError(
"Can't define more than one behavior using the same "
"self.mock_callable() chain. Please call self.mock_callable() again "
"one time for each new behavior."
)
if self._next_runner_accepted_args:
args, kwargs = self._next_runner_accepted_args
self._next_runner_accepted_args = None
runner.add_accepted_args(self._accept_partial_call, *args, **kwargs)
self._accept_partial_call = False
self._runner = runner
self._callable_mock.runners.insert(0, runner) # type: ignore
def _assert_runner(self) -> None:
if not self._runner:
raise ValueError(
"You must first define a behavior. Eg: "
"self.mock_callable(target, 'func')"
".to_return_value(value)"
".and_assert_called_exactly(times)"
)
if self._runner._call_count > 0:
raise ValueError(
f"No extra configuration is allowed after {self._NAME} "
f"receives its first call, it received {self._runner._call_count} "
f"call{'s' if self._runner._call_count > 1 else ''} already. "
"You should instead define it all at once, "
f"eg: self.{self._NAME}(target, 'func')"
".to_return_value(value).and_assert_called_once()"
)
##
## Arguments
##
def for_call(
self, *args: Any, **kwargs: Any
) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
"""
Filter for only calls like this.
"""
if self._runner:
self._runner.add_accepted_args(False, *args, **kwargs)
else:
self._next_runner_accepted_args = (args, kwargs)
return self
def for_partial_call(
self, *args: Any, **kwargs: Any
) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
if self._runner:
self._runner.add_accepted_args(True, *args, **kwargs)
else:
self._accept_partial_call = True
self._next_runner_accepted_args = (args, kwargs)
return self
##
## Behavior
##
def to_return_value(
self, value: Any
) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
"""
Always return given value.
"""
self._add_runner(
_ReturnValueRunner(
self._original_target,
self._method,
self._original_callable, # type: ignore
value,
self._allow_coro,
)
)
return self
def to_return_values(
self, values_list: List[Any]
) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
"""
For each call, return each value from given list in order.
When list is exhausted, goes to the next behavior set.
"""
if not isinstance(values_list, list):
raise ValueError("{} is not a list".format(values_list))
self._add_runner(
_ReturnValuesRunner(
self._original_target,
self._method,
self._original_callable, # type: ignore
values_list,
self._allow_coro,
)
)
return self
def to_yield_values(
self, values_list: List[Any]
) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
"""
Callable will return an iterator what will yield each value from the
given list.
"""
if not isinstance(values_list, list):
raise ValueError("{} is not a list".format(values_list))
self._add_runner(
_YieldValuesRunner(
self._original_target,
self._method,
self._original_callable, # type: ignore
values_list,
self._allow_coro,
)
)
return self
def to_raise(
self, ex: Union[Type[BaseException], BaseException]
) -> Union["_MockCallableDSL", "_MockAsyncCallableDSL", "_MockConstructorDSL"]:
"""
Raises given exception class or exception instance.
"""
if isinstance(ex, BaseException):
self._add_runner(
_RaiseRunner(
self._original_target, self._method, self._original_callable, ex # type: ignore
)
)
elif isinstance(ex(), BaseException):
self._add_runner(
_RaiseRunner(