/
patch.py
112 lines (91 loc) · 3.89 KB
/
patch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect
from typing import Any, Callable, Dict, Optional, Union
class _DescriptorProxy:
def __init__(
self,
original_class_attr: Optional[Union[Callable, "_DescriptorProxy"]],
attr_name: str,
) -> None:
self.original_class_attr = original_class_attr
self.attr_name = attr_name
self.instance_attr_map: Dict[int, Callable] = {}
def __set__(self, instance: object, value: Callable) -> None:
self.instance_attr_map[id(instance)] = value
def __get__(
self, instance: object, owner: object
) -> Union[Callable, "_DescriptorProxy"]:
if instance is None:
return self
if id(instance) in self.instance_attr_map:
return self.instance_attr_map[id(instance)]
if self.original_class_attr:
return self.original_class_attr.__get__(instance, owner) # type: ignore
for parent in owner.mro()[1:]: # type: ignore
method = parent.__dict__.get(self.attr_name, None)
if type(method) is type(self):
continue
if method:
return method.__get__(instance, owner)
return instance.__get__(instance, owner) # type: ignore
def __delete__(self, instance: object) -> None:
if id(instance) in self.instance_attr_map:
del self.instance_attr_map[id(instance)]
def _is_instance_method(target: Any, method: str) -> bool:
    """
    Tell whether ``method`` resolves to a plain function on ``target``'s class.

    Modules never qualify. A class ``target`` is inspected directly; any other
    object is inspected through ``type(target)``. The MRO is walked the same way
    normal attribute lookup walks it: the first class that defines ``method``
    decides, so a non-function attribute that shadows a parent's method is not
    reported as an instance method.

    A ``_DescriptorProxy`` left by an earlier patch is unwrapped to the attribute
    it replaced. A proxy that replaced nothing (``original_class_attr`` is None,
    i.e. the attribute was inherited) defers to the classes after it.
    """
    if inspect.ismodule(target):
        return False
    klass = target if inspect.isclass(target) else type(target)
    for k in klass.mro():
        if method not in k.__dict__:
            continue
        value = k.__dict__[method]
        if isinstance(value, _DescriptorProxy):
            # Proxies can be nested when the same attribute is patched repeatedly.
            while isinstance(value, _DescriptorProxy):
                value = value.original_class_attr
            if value is None:
                continue
        return inspect.isfunction(value)
    return False
def _mock_instance_attribute(instance: Any, attr: str, value: Any) -> Callable:
    """
    Patch an attribute on a single instance with the given value.

    This works for any instance attribute, even when the attribute is defined at
    the class via the descriptor protocol (``__get__``, e.g. with ``@property``).
    Only the desired instance is affected. By contrast, Python's
    ``unittest.mock.patch.object`` + ``PropertyMock`` must patch at the class
    level, which affects every instance.

    A ``_DescriptorProxy`` wrapping the class's current attribute is installed on
    ``type(instance)``. ``value`` is then assigned to the instance through it.

    :returns: a no-argument callable that drops the instance's patched value and
        puts the class attribute back as it was. The attribute is deleted from
        the class if the class did not define it itself.
    """
    klass = type(instance)
    # Membership, not truthiness: a falsy own class attribute must be restored,
    # not deleted.
    had_class_attr = attr in klass.__dict__
    class_restore_value = klass.__dict__.get(attr, None)
    proxy = _DescriptorProxy(class_restore_value, attr)
    setattr(klass, attr, proxy)

    def unpatch_class() -> None:
        # Clear this instance's entry first. A nested patch that is unwound out
        # of order can reinstall this proxy, and the instance must not see the
        # mock again when that happens.
        proxy.__delete__(instance)
        if had_class_attr:
            setattr(klass, attr, class_restore_value)
        else:
            delattr(klass, attr)

    try:
        setattr(instance, attr, value)
    except BaseException:
        # Don't leave the class proxied if the instance rejects the assignment.
        unpatch_class()
        raise
    return unpatch_class
def _patch(
    target: Any, attribute: str, new_value: Any, restore: Any, restore_value: Any = None
) -> Callable:
    """
    Replace ``attribute`` on ``target`` with ``new_value`` and return an unpatcher.

    The strategy depends on how the attribute is currently defined:

    * an instance method (per ``_is_instance_method``) is patched for ``target``
      alone, through ``_mock_instance_attribute``;
    * a ``property`` on ``type(target)`` is swapped, on the class, for a
      read-only property returning ``new_value``, so every instance sees it;
    * anything else is set directly on ``target``.

    :param restore: truthy if unpatching should put ``restore_value`` back on
        ``target``; falsy if it should delete the attribute instead. Consulted
        only by the plain-attribute strategy.
    :param restore_value: the value put back when ``restore`` is truthy.
    :returns: a no-argument callable that undoes the patch.
    """
    if _is_instance_method(target, attribute):
        return _mock_instance_attribute(target, attribute, new_value)

    klass = type(target)
    original_property = getattr(klass, attribute, None)
    if isinstance(original_property, property):
        # Record whether the class defined the property itself or inherited it,
        # so unpatching leaves the class dict exactly as it was found.
        owned = attribute in klass.__dict__
        setattr(klass, attribute, property(fget=lambda _: new_value))

        def unpatch_property() -> None:
            # The property always existed, so always put it back. Deleting from
            # ``target`` cannot work: the replacement property has no deleter.
            if owned:
                setattr(klass, attribute, original_property)
            else:
                delattr(klass, attribute)

        return unpatch_property

    setattr(target, attribute, new_value)

    def unpatch_attribute() -> None:
        # Decide on ``restore``, not on the truthiness of ``restore_value``:
        # a falsy original (0, "", None, False) must still be restored.
        if restore:
            setattr(target, attribute, restore_value)
        else:
            delattr(target, attribute)

    return unpatch_attribute