Skip to content

Commit

Permalink
Ensure the pool stops when encountering exception
Browse files Browse the repository at this point in the history
  • Loading branch information
CarlGao4 committed Nov 14, 2023
1 parent 9608782 commit 07845f4
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions demucs/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,11 @@ def apply_model(model: tp.Union[BagOfModels, Model],
if progress:
futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
for future, offset in futures:
chunk_out = future.result() # type: th.Tensor
try:
chunk_out = future.result() # type: th.Tensor
except Exception:
pool.shutdown(wait=True, cancel_futures=True)
raise
chunk_length = chunk_out.shape[-1]
out[..., offset:offset + segment_length] += (
weight[:chunk_length] * chunk_out).to(mix.device)
Expand All @@ -308,19 +312,11 @@ def apply_model(model: tp.Union[BagOfModels, Model],
padded_mix = mix.padded(valid_length).to(device)
with lock:
if callback is not None:
try:
callback(_replace_dict(callback_arg, ("state", "start"))) # type: ignore
except Exception:
pool.shutdown(wait=True, cancel_futures=True)
raise
callback(_replace_dict(callback_arg, ("state", "start"))) # type: ignore
with th.no_grad():
out = model(padded_mix)
with lock:
if callback is not None:
try:
callback(_replace_dict(callback_arg, ("state", "end"))) # type: ignore
except Exception:
pool.shutdown(wait=True, cancel_futures=True)
raise
callback(_replace_dict(callback_arg, ("state", "end"))) # type: ignore
assert isinstance(out, th.Tensor)
return center_trim(out, length)

0 comments on commit 07845f4

Please sign in to comment.