-
Notifications
You must be signed in to change notification settings - Fork 47
/
contexts.py
151 lines (110 loc) · 4.55 KB
/
contexts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Copyright 2016 Quora, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asynq
import asyncio
from .asynq_to_async import is_asyncio_mode
from . import debug
from ._debug import options as _debug_options
# Name of the attribute, stored on an asyncio task, that holds a dict mapping
# id(context) -> context for the contexts entered while that task was current.
ASYNCIO_CONTEXT_FIELD = "_asynq_contexts"
# Name of the boolean attribute, stored on an asyncio task, that
# pause_contexts_asyncio / resume_contexts_asyncio toggle to record whether the
# task's contexts are currently resumed.
ASYNCIO_CONTEXT_ACTIVE_FIELD = "_asynq_contexts_active"
class NonAsyncContext(object):
    """Indicates that a context can't contain yield statements.

    While a NonAsyncContext is active, async tasks cannot yield control back
    to the scheduler.

    Contexts should subclass this class if they want it to throw an
    AssertionError when async tasks end up yielding within them.

    Note: Remember to call super inside your __enter__/__exit__.

    """

    def __enter__(self):
        """Register this context with the current task (asynq or asyncio)."""
        if is_asyncio_mode():
            enter_context_asyncio(self)
        else:
            self._active_task = enter_context(self)

    def __exit__(self, typ, val, tb):
        """Unregister this context from the task it was registered with."""
        if is_asyncio_mode():
            leave_context_asyncio(self)
        else:
            leave_context(self, self._active_task)

    def pause(self):
        """Always raises AssertionError: yielding inside this context is forbidden."""
        self._raise_cannot_yield()

    def resume(self):
        """Always raises AssertionError: yielding inside this context is forbidden."""
        self._raise_cannot_yield()

    def _raise_cannot_yield(self):
        """Raise the AssertionError shared by pause() and resume()."""
        # _active_task is only assigned by __enter__ in non-asyncio mode, so it
        # may be missing here; fall back to None instead of letting an
        # AttributeError mask the real problem.
        task = getattr(self, "_active_task", None)
        # Raised explicitly (rather than via `assert False`) so the check is
        # not stripped when Python runs with -O.
        raise AssertionError(
            "Task %s cannot yield while %s is active" % (task, self)
        )
def enter_context(context):
    """Attach ``context`` to the scheduler's currently active task, if any.

    Returns the active task the context was attached to, or None when no task
    is active (in which case nothing is attached).

    """
    # perf optimization: inline get_active_task
    task = asynq.scheduler._state.current.active_task
    if task is None:
        return None
    task._enter_context(context)
    return task
def leave_context(context, active_task):
    """Detach ``context`` from ``active_task``.

    ``active_task`` is the value previously returned by enter_context; when it
    is None there is nothing to detach from and the call is a no-op.

    """
    if active_task is None:
        return
    active_task._leave_context(context)
def enter_context_asyncio(context):
    """Record ``context`` on the current asyncio task.

    The context is stored under id(context) in the dict kept in the task's
    ASYNCIO_CONTEXT_FIELD attribute; the dict is created on first use.

    """
    if _debug_options.DUMP_CONTEXTS:
        debug.write("@async: +context: %s" % debug.str(context))

    # since we are in asyncio mode, there is an active task
    current = asyncio.current_task()

    try:
        registry = getattr(current, ASYNCIO_CONTEXT_FIELD)
    except AttributeError:
        # First context entered on this task: start a fresh registry.
        registry = {}
        setattr(current, ASYNCIO_CONTEXT_FIELD, registry)
    registry[id(context)] = context
def leave_context_asyncio(context):
    """Remove ``context`` from the current asyncio task's registry.

    Does nothing if the task has no registry or the context is not in it.

    """
    if _debug_options.DUMP_CONTEXTS:
        debug.write("@async: -context: %s" % debug.str(context))

    current = asyncio.current_task()
    registry = getattr(current, ASYNCIO_CONTEXT_FIELD, {})
    registry.pop(id(context), None)  # type: ignore
def pause_contexts_asyncio(task):
    """Pause every context registered on ``task`` if the task is marked active.

    Clears the task's active flag first, then calls pause() on the registered
    contexts in reverse registration order. A task whose flag is unset or
    False is left untouched.

    """
    # NOTE(review): an unset flag is read as False here but as True in
    # resume_contexts_asyncio, so neither function acts until something sets
    # the flag explicitly -- confirm that is intended.
    if not getattr(task, ASYNCIO_CONTEXT_ACTIVE_FIELD, False):
        return

    setattr(task, ASYNCIO_CONTEXT_ACTIVE_FIELD, False)
    # Iterate over a copy, so the loop is unaffected if the registry changes
    # while contexts are being paused.
    registered = list(getattr(task, ASYNCIO_CONTEXT_FIELD, {}).values())
    for ctx in reversed(registered):
        ctx.pause()
def resume_contexts_asyncio(task):
    """Resume every context registered on ``task`` if the task is marked inactive.

    Sets the task's active flag first, then calls resume() on the registered
    contexts in registration order. A task whose flag is unset or True is
    left untouched.

    """
    # NOTE(review): an unset flag is read as True here but as False in
    # pause_contexts_asyncio, so neither function acts until something sets
    # the flag explicitly -- confirm that is intended.
    if getattr(task, ASYNCIO_CONTEXT_ACTIVE_FIELD, True):
        return

    setattr(task, ASYNCIO_CONTEXT_ACTIVE_FIELD, True)
    registry = getattr(task, ASYNCIO_CONTEXT_FIELD, {})
    for ctx in registry.values():
        ctx.resume()
class AsyncContext(object):
    """Base class for contexts that pause and resume along with an async function.

    Subclasses must implement at least pause() and resume(). The context is
    then paused and resumed each time the execution of the async function
    running inside it is paused and resumed.

    __enter__ and __exit__ may also be overridden to customize behaviour;
    remember to call super in that case.

    NOTE: __enter__/__exit__ already call resume()/pause(), so overriding
    implementations shouldn't call them explicitly.

    """

    def __enter__(self):
        """Register the context with the current task, resume it, return self."""
        if not is_asyncio_mode():
            # Remember the task so __exit__ detaches from the same one.
            self._active_task = enter_context(self)
        else:
            enter_context_asyncio(self)
        self.resume()
        return self

    def __exit__(self, ty, value, tb):
        """Unregister the context from its task and pause it."""
        if is_asyncio_mode():
            leave_context_asyncio(self)
            self.pause()
            return

        leave_context(self, self._active_task)
        self.pause()
        # Only the non-asyncio __enter__ path sets this attribute.
        del self._active_task

    def resume(self):
        """Hook called when execution (re-)enters the context; must be overridden."""
        raise NotImplementedError()

    def pause(self):
        """Hook called when execution leaves the context; must be overridden."""
        raise NotImplementedError()