--wip-- [skip ci]
sauyon committed Apr 20, 2023
1 parent a4c3eac commit 73cae40
Showing 1 changed file with 59 additions and 47 deletions.

src/bentoml/_internal/marshal/dispatcher.py
@@ -15,6 +15,7 @@

from ..utils import cached_property
from ..utils.alg import TokenBucket
from ...exceptions import BadInput

logger = logging.getLogger(__name__)

@@ -61,7 +62,7 @@ def __init__(self, options: dict[str, t.Any]):
pass

@abstractmethod
def log_outbound(self, n: int, wait: float, duration: float):
def log_outbound(self, batch_size: int, duration: float):
pass

@abstractmethod
@@ -76,7 +77,7 @@ def predict_diff(self, first_batch_size: int, second_batch_size: int) -> float:

def __init_subclass__(cls, optimizer_id: str):
OPTIMIZER_REGISTRY[optimizer_id] = cls
cls.strategy_id = optimizer_id
cls.optimizer_id = optimizer_id


class FixedOptimizer(Optimizer, optimizer_id="fixed"):
@@ -95,8 +96,8 @@ class LinearOptimizer(Optimizer, optimizer_id="linear"):
"""
Analyze historical data to predict execution time using a simple linear regression on batch size.
"""
o_a: int = 2
o_b: int = 1
o_a: float = 2
o_b: float = 1

n_kept_sample = 50 # amount of outbound info kept for inferring params
n_skipped_sample = 2 # amount of outbound info skipped after init
@@ -121,25 +122,21 @@ def __init__(self, options: dict[str, t.Any]):
else:
logger.warning("Strategy 'target_latency' ignoring unknown configuration key '{key}'.")

self.o_stat: collections.deque[tuple[int, float, float]] = collections.deque(
self.o_stat: collections.deque[tuple[int, float]] = collections.deque(
maxlen=self.n_kept_sample
) # to store outbound stat data
self.o_a = min(2, max_latency * 2.0 / 30)
self.o_b = min(1, max_latency * 1.0 / 30)

self.wait = 0 # the avg wait time before outbound called

self._refresh_tb = TokenBucket(2) # to limit params refresh interval
self.outbound_counter = 0

def log_outbound(self, n: int, wait: float, duration: float):
def log_outbound(self, batch_size: int, duration: float):
if self.outbound_counter <= self.n_skipped_sample + 4:
self.outbound_counter += 1
# skip inaccurate info at beginning
if self.outbound_counter <= self.n_skipped_sample:
return

self.o_stat.append((n, duration, wait))
self.o_stat.append((batch_size, duration))

if self._refresh_tb.consume(1, 1.0 / self.param_refresh_interval, 1):
self.trigger_refresh()
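
The `_refresh_tb` guard above rate-limits how often the regression is refit: `consume(1, 1.0 / self.param_refresh_interval, 1)` asks for one token, with tokens accruing at one per `param_refresh_interval` seconds up to a burst of one. A minimal sketch of a token bucket with that `consume(take, avg_rate, burst)` shape — an illustration, not BentoML's actual `TokenBucket`:

import time

class RefreshLimiter:
    """Token bucket: allow roughly avg_rate actions per second, up to a burst."""

    def __init__(self, init_tokens: float = 0.0):
        self._tokens = init_tokens
        self._last = time.time()

    def consume(self, take: float, avg_rate: float, burst: float) -> bool:
        now = time.time()
        # tokens accrue with elapsed time, capped at the burst size
        self._tokens = min(self._tokens + (now - self._last) * avg_rate, burst)
        self._last = now
        if take > self._tokens:
            return False  # not enough budget yet; skip this refresh
        self._tokens -= take
        return True

limiter = RefreshLimiter()
print(limiter.consume(1, 1.0 / 5.0, 1))  # False until ~5 seconds of budget accrue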
@@ -151,20 +148,17 @@ def predict_diff(self, first_batch_size: int, second_batch_size: int):
return self.o_a * (second_batch_size - first_batch_size)

def trigger_refresh(self):
x = tuple((i, 1) for i, _, _ in self.o_stat)
y = tuple(i for _, i, _ in self.o_stat)
x = tuple((i, 1) for i, _ in self.o_stat)
y = tuple(i for _, i in self.o_stat)

_factors: tuple[float, float] = np.linalg.lstsq(x, y, rcond=None)[0] # type: ignore
_factors = t.cast(tuple[float, float], np.linalg.lstsq(x, y, rcond=None)[0])
_o_a, _o_b = _factors
_o_w = sum(w for _, _, w in self.o_stat) * 1.0 / len(self.o_stat)

self.o_a, self.o_b = max(0.000001, _o_a), max(0, _o_b)
self.wait = max(0, _o_w)
logger.debug(
"Dynamic batching optimizer params updated: o_a: %.6f, o_b: %.6f, wait: %.6f",
"Dynamic batching optimizer params updated: o_a: %.6f, o_b: %.6f",
_o_a,
_o_b,
_o_w,
)
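
`trigger_refresh` above is an ordinary least-squares fit of `duration ≈ o_a * batch_size + o_b` over the kept samples: the `(i, 1)` rows form a design matrix whose second column fits the intercept, and `predict(n)` (which this diff substitutes for the inline `a * n + b` further down) evaluates the fitted line. A self-contained sketch of the same fit, with made-up sample data:

import numpy as np

# hypothetical (batch_size, duration-in-seconds) observations
samples = [(1, 0.012), (4, 0.021), (8, 0.035), (16, 0.060)]

x = np.array([(n, 1.0) for n, _ in samples])  # design matrix: [batch_size, intercept]
y = np.array([d for _, d in samples])
(o_a, o_b), *_ = np.linalg.lstsq(x, y, rcond=None)

def predict(batch_size: int) -> float:
    # estimated execution time for a batch of this size
    return o_a * batch_size + o_b

print(predict(10))  # ~0.041 with the sample data above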


@@ -183,7 +177,7 @@ def __init__(self, optimizer: Optimizer, options: dict[t.Any, t.Any]):
pass

@abstractmethod
def wait(self, optimizer: Optimizer, queue: t.Sequence[Job], max_latency: float, max_batch_size: int, tick_interval: float):
async def batch(self, optimizer: Optimizer, queue: t.Deque[Job], max_latency: float, max_batch_size: int, tick_interval: float, dispatch: t.Callable[[t.Sequence[Job]], None]):
pass

def __init_subclass__(cls, strategy_id: str):
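
With this interface change, a strategy no longer just computes a wait time: `batch` owns the whole decision loop and hands the chosen jobs to the `dispatch` callback. A minimal conforming strategy that dispatches immediately might look like the following sketch (the `greedy` id is hypothetical, for illustration only):

class GreedyStrategy(BatchingStrategy, strategy_id="greedy"):
    """Dispatch whatever is queued right away, up to max_batch_size."""

    def __init__(self, options: dict[t.Any, t.Any]):
        for key in options:
            logger.warning("Strategy 'greedy' ignoring unknown configuration key '%s'.", key)

    async def batch(self, optimizer, queue, max_latency, max_batch_size, tick_interval, dispatch):
        # take up to max_batch_size jobs from the newest end, as the built-in strategies do
        n_call_out = min(max_batch_size, len(queue))
        dispatch([queue.pop() for _ in range(n_call_out)])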
@@ -202,7 +196,8 @@ def __init__(self, options: dict[t.Any, t.Any]):
logger.warning("Strategy 'target_latency' ignoring unknown configuration key '{key}'.")


async def wait(self, optimizer: Optimizer, queue: t.Sequence[Job], max_latency: float, max_batch_size: int, tick_interval: float):
async def batch(self, optimizer: Optimizer, queue: t.Deque[Job], max_latency: float, max_batch_size: int, tick_interval: float, dispatch: t.Callable[[t.Sequence[Job]], None]):
n = len(queue)
now = time.time()
w0 = now - queue[0].enqueue_time
latency_0 = w0 + optimizer.predict(n)
@@ -219,14 +214,20 @@
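
The rest of this method is collapsed in this view. Based on the visible setup, the likely shape is to sleep in `tick_interval` steps until the oldest request's predicted completion time approaches the latency budget, then hand off a batch — a sketch under that assumption, not the commit's actual body:

# sketch of the collapsed wait-then-dispatch loop (assumed, not shown in the diff)
while n < max_batch_size and latency_0 + tick_interval <= max_latency * 0.95:
    await asyncio.sleep(tick_interval)
    n = len(queue)
    now = time.time()
    w0 = now - queue[0].enqueue_time
    latency_0 = w0 + optimizer.predict(n)

n_call_out = min(max_batch_size, n)
dispatch([queue.pop() for _ in range(n_call_out)])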
class AdaptiveStrategy(BatchingStrategy, strategy_id="adaptive"):
decay: float = 0.95

n_kept_samples = 50
avg_wait_times: collections.deque[float]
avg_req_wait: float = 0

def __init__(self, options: dict[t.Any, t.Any]):
for key in options:
if key == "decay":
self.decay = options[key]
else:
logger.warning("Strategy 'adaptive' ignoring unknown configuration value")

async def wait(self, optimizer: Optimizer, queue: t.Sequence[Job], max_latency: float, max_batch_size: int, tick_interval: float):
self.avg_wait_times = collections.deque(maxlen=self.n_kept_samples)

async def batch(self, optimizer: Optimizer, queue: t.Deque[Job], max_latency: float, max_batch_size: int, tick_interval: float, dispatch: t.Callable[[t.Sequence[Job]], None]):
n = len(queue)
now = time.time()
w0 = now - queue[0].enqueue_time
@@ -238,7 +239,7 @@ async def wait(self, optimizer: Optimizer, queue: t.Sequence[Job], max_latency:
# we are not about to cancel the first request,
and latency_0 + tick_interval <= max_latency * 0.95
# and waiting will cause average latency to decrease
and n * (wn + tick_interval + optimizer.predict_diff(n, n+1)) <= optimizer.wait * self.decay
and n * (wn + tick_interval + optimizer.predict_diff(n, n+1)) <= self.avg_req_wait * self.decay
):
n = len(queue)
now = time.time()
@@ -248,9 +249,28 @@ async def wait(self, optimizer: Optimizer, queue: t.Sequence[Job], max_latency:
# wait for additional requests to arrive
await asyncio.sleep(tick_interval)

# dispatch the batch
now = time.time()  # refresh the timestamp; `now` is stale after the last sleep above
n = len(queue)
n_call_out = min(max_batch_size, n)
# pop jobs off the queue, recording each one's wait in the sliding window
inputs_info: list[Job] = []
for _ in range(n_call_out):
job = queue.pop()
new_wait = (now - job.enqueue_time) / self.n_kept_samples
if len(self.avg_wait_times) == self.n_kept_samples:
oldest_wait = self.avg_wait_times.popleft()
self.avg_req_wait = self.avg_req_wait - oldest_wait + new_wait
else:
self.avg_req_wait += new_wait
self.avg_wait_times.append(new_wait)
inputs_info.append(job)

dispatch(inputs_info)
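
The bookkeeping above keeps `avg_req_wait` equal to the sum of the last `n_kept_samples` contributions, each pre-divided by the window size, so the rolling mean updates in O(1) per job rather than re-summing the window (it also under-reports until the window fills). A standalone illustration of the same update rule:

import collections

WINDOW = 3
contributions: collections.deque[float] = collections.deque(maxlen=WINDOW)
avg = 0.0

for observed in (0.30, 0.60, 0.90, 1.20):  # raw wait times, in seconds
    new = observed / WINDOW
    if len(contributions) == WINDOW:
        avg -= contributions.popleft()  # retire the oldest contribution
    avg += new
    contributions.append(new)
    print(round(avg, 2))  # prints 0.1, 0.3, 0.6, then 0.9 (mean of last 3 samples)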


class Dispatcher:
background_tasks: set[asyncio.Task[None]] = set()

def __init__(
self,
max_latency_in_ms: int,
@@ -283,6 +303,8 @@ def __init__(
def shutdown(self):
if self._controller is not None:
self._controller.cancel()
for task in self.background_tasks:
task.cancel()
try:
while True:
fut = self._queue.pop().future
@@ -329,10 +351,11 @@ async def train_optimizer(
if self.max_batch_size < batch_size:
batch_size = self.max_batch_size

wait = 0
if batch_size > 1:
wait = min(
self.max_latency * 0.95,
(batch_size * 2 + 1) * (self.optimizer.o_a + self.optimizer.o_b),
self.optimizer.predict(batch_size * 2 + 1),
)

req_count = 0
@@ -351,10 +374,7 @@ async def train_optimizer(
self._queue.popleft().future.cancel()
continue
if batch_size > 1: # only wait if batch_size
a = self.optimizer.o_a
b = self.optimizer.o_b

if n < batch_size and (batch_size * a + b) + w0 <= wait:
if n < batch_size and self.optimizer.predict(batch_size) + w0 <= wait:
await asyncio.sleep(self.tick_interval)
continue
if self._sema.is_locked():
@@ -400,7 +420,7 @@ async def controller(self):
self.optimizer.trigger_refresh()
logger.debug("Dispatcher finished optimizer training request 3.")

if self.optimizer.o_a + self.optimizer.o_b >= self.max_latency:
if self.optimizer.predict(1) >= self.max_latency:
logger.warning(
"BentoML has detected that a service has a max latency that is likely too low for serving. If many 503 errors are encountered, try raising the 'runner.max_latency' in your BentoML configuration YAML file."
)
@@ -414,16 +434,11 @@ async def controller(self):
await self._wake_event.wait_for(self._queue.__len__)

n = len(self._queue)
dt = self.tick_interval
decay = 0.95 # the decay rate of wait time
now = time.time()
w0 = now - self._queue[0].enqueue_time
wn = now - self._queue[-1].enqueue_time
a = self.optimizer.o_a
b = self.optimizer.o_b

# the estimated latency of the first request if we began processing now
latency_0 = w0 + a * n + b
latency_0 = w0 + self.optimizer.predict(n)

if n > 1 and latency_0 >= self.max_latency:
self._queue.popleft().future.cancel()
@@ -436,20 +451,18 @@ async def controller(self):
continue

# we are now free to dispatch whenever we like
await self.strategy.wait(self.optimizer, self._queue, self.max_latency, self.max_batch_size, self.tick_interval)

n = len(self._queue)

n_call_out = min(self.max_batch_size, n)
# call
self._sema.acquire()
inputs_info = tuple(self._queue.pop() for _ in range(n_call_out))
self._loop.create_task(self.outbound_call(inputs_info))
await self.strategy.batch(self.optimizer, self._queue, self.max_latency, self.max_batch_size, self.tick_interval, self._dispatch)
except asyncio.CancelledError:
return
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc(), exc_info=e)

def _dispatch(self, inputs_info: t.Sequence[Job]):
self._sema.acquire()
task = self._loop.create_task(self.outbound_call(inputs_info))
self.background_tasks.add(task)
task.add_done_callback(self.background_tasks.discard)
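
`_dispatch` keeps a strong reference to each outbound task and drops it when the task finishes; asyncio's event loop holds only weak references to tasks, so without this set a fire-and-forget task can be garbage-collected before it completes. The same keep-alive pattern in isolation:

import asyncio

background_tasks: set[asyncio.Task[None]] = set()

async def work(i: int) -> None:
    await asyncio.sleep(0.1)
    print("done", i)

def fire_and_forget(coro) -> None:
    task = asyncio.get_running_loop().create_task(coro)
    background_tasks.add(task)                        # strong reference keeps the task alive
    task.add_done_callback(background_tasks.discard)  # drop it once finished

async def main() -> None:
    for i in range(3):
        fire_and_forget(work(i))
    await asyncio.sleep(0.2)  # let the background tasks finish

asyncio.run(main())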

async def inbound_call(self, data: t.Any):
now = time.time()
future = self._loop.create_future()
Expand All @@ -459,7 +472,7 @@ async def inbound_call(self, data: t.Any):
self._wake_event.notify_all()
return await future

async def outbound_call(self, inputs_info: tuple[Job, ...]):
async def outbound_call(self, inputs_info: t.Sequence[Job]):
_time_start = time.time()
_done = False
batch_size = len(inputs_info)
@@ -474,9 +487,8 @@ async def outbound_call(self, inputs_info: tuple[Job, ...]):
if not fut.done():
fut.set_result(out)
_done = True
self.strategy.log_outbound(
n=len(inputs_info),
wait=_time_start - inputs_info[-1].enqueue_time,
self.optimizer.log_outbound(
batch_size=len(inputs_info),
duration=time.time() - _time_start,
)
except asyncio.CancelledError:
