Add SIGINT handling (#104)
* Refactor loop runner

* Add task cancelling

* Handle outputs from tasks properly

* Gather in the right order

* Explicit is better than implicit

* No signal handler for you, Windows

* Cancel writer before remove, or Windows gets sad

* Cancel FTP properly
Cadair committed Jul 1, 2022
1 parent ceaa210 commit 1e802fe
Showing 3 changed files with 101 additions and 46 deletions.
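
The heart of the change is wiring SIGINT/SIGTERM to cancellation of the top-level download task, with Windows excluded because its event loop has no add_signal_handler. A minimal standalone sketch of that pattern, using hypothetical names rather than parfive's exact code:

    import os
    import signal
    import asyncio

    async def download():
        # Stand-in for the real download coroutine.
        await asyncio.sleep(60)
        return "done"

    def run(coro):
        loop = asyncio.new_event_loop()
        task = loop.create_task(coro)
        # Windows event loops don't implement add_signal_handler, so skip it there.
        if os.name != "nt":
            for sig in (signal.SIGINT, signal.SIGTERM):
                loop.add_signal_handler(sig, task.cancel)
        try:
            return loop.run_until_complete(task)
        except asyncio.CancelledError:
            return None

    print(run(download()))  # Ctrl+C now cancels the task instead of dumping a traceback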
134 changes: 91 additions & 43 deletions parfive/downloader.py
@@ -1,19 +1,19 @@
import os
import signal
import asyncio
import logging
import pathlib
import warnings
import contextlib
import urllib.parse
from typing import Dict, Union, Callable, Optional
from typing import Union, Callable, Optional
from functools import reduce

try:
from typing import Literal # Added in Python 3.8
except ImportError:
from typing_extensions import Literal # type: ignore

from functools import partial
from concurrent.futures import ThreadPoolExecutor

import aiohttp
from tqdm import tqdm as tqdm_std
@@ -25,7 +25,6 @@
from .utils import (
FailedDownload,
MultiPartDownloadError,
ParfiveFutureWarning,
Token,
_QueueList,
cancel_task,
@@ -34,7 +33,7 @@
get_ftp_size,
get_http_size,
remove_file,
run_in_thread,
run_task_in_thread,
)

try:
@@ -220,23 +219,40 @@ def filepath(url, resp):
raise ValueError("URL must start with either 'http' or 'ftp'.")

@staticmethod
def _run_in_loop(coro):
"""
Detect an existing, running loop and run in a separate loop if needed.
def _add_shutdown_signals(loop, task):
if os.name == "nt":
return
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, task.cancel)

If no loop is running, use asyncio.run to run the coroutine instead.
def _run_in_loop(self, coro):
"""
Take a coroutine and figure out where to run it and how to cancel it.
"""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None

if loop and loop.is_running():
aio_pool = ThreadPoolExecutor(1)
new_loop = asyncio.new_event_loop()
return run_in_thread(aio_pool, new_loop, coro)
# If we already have a loop and it's already running then we should
# make a new loop (as we are probably in a Jupyter Notebook)
should_run_in_thread = loop and loop.is_running()

# If we don't already have a loop, make a new one
if should_run_in_thread or loop is None:
loop = asyncio.new_event_loop()

# Wrap up the coroutine in a task so we can cancel it later
task = loop.create_task(coro)

return asyncio.run(coro)
# Add handlers for shutdown signals
self._add_shutdown_signals(loop, task)

# Execute the task
if should_run_in_thread:
return run_task_in_thread(loop, task)

return loop.run_until_complete(task)

async def run_download(self):
"""
@@ -248,18 +264,31 @@ async def run_download(self):
A list of files downloaded.
"""
total_files = self.queued_downloads

done = set()
with self._get_main_pb(total_files) as main_pb:
if len(self.http_queue):
done.update(await self._run_http_download(main_pb))
if len(self.ftp_queue):
done.update(await self._run_ftp_download(main_pb))

dl_results = await asyncio.gather(*done, return_exceptions=True)
errors = sum([isinstance(i, FailedDownload) for i in dl_results])
tasks = set()
with self._get_main_pb(self.queued_downloads) as main_pb:
try:
if len(self.http_queue):
tasks.add(asyncio.create_task(self._run_http_download(main_pb)))
if len(self.ftp_queue):
tasks.add(asyncio.create_task(self._run_ftp_download(main_pb)))
dl_results = await asyncio.gather(*tasks, return_exceptions=True)

except asyncio.CancelledError:
for task in tasks:
task.cancel()
dl_results = await asyncio.gather(*tasks, return_exceptions=True)

finally:
results_obj = self._format_results(dl_results, main_pb)
return results_obj

def _format_results(self, retvals, main_pb):
# Squash all nested lists into a single flat list
if retvals:
retvals = list(reduce(list.__add__, retvals))
errors = sum([isinstance(i, FailedDownload) for i in retvals])
if errors:
total_files = self.queued_downloads
message = f"{errors}/{total_files} files failed to download. Please check `.errors` for details"
if main_pb:
main_pb.write(message)
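
On cancellation, run_download now cancels its child tasks and then gathers them again with return_exceptions=True, so exceptions and partial results are collected instead of lost. A toy version of that cancel-then-gather pattern (assumed worker names, not parfive's code):

    import asyncio

    async def worker(i):
        await asyncio.sleep(i)
        return i

    async def main():
        tasks = {asyncio.create_task(worker(i)) for i in range(3)}
        try:
            results = await asyncio.gather(*tasks, return_exceptions=True)
        except asyncio.CancelledError:
            # Cancel whatever is still running, then await everything again;
            # return_exceptions=True turns each cancellation into a result entry.
            for t in tasks:
                t.cancel()
            results = await asyncio.gather(*tasks, return_exceptions=True)
        return results

    print(asyncio.run(main()))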
@@ -270,7 +299,7 @@ async def run_download(self):

# Iterate through the results and store any failed download errors in
# the errors list of the results object.
for res in dl_results:
for res in retvals:
if isinstance(res, FailedDownload):
results.add_error(res.filepath_partial, res.url, res.exception)
parfive.log.info(
@@ -372,29 +401,37 @@ def _get_main_pb(self, total):

async def _run_http_download(self, main_pb):
async with self.config.aiohttp_client_session() as session:
self._generate_tokens()
futures = await self._run_from_queue(
self.http_queue.generate_queue(),
self._generate_tokens(),
main_pb,
session=session,
)

# Wait for all the coroutines to finish
done, _ = await asyncio.wait(futures)
try:
# Wait for all the coroutines to finish
done, _ = await asyncio.wait(futures)
except asyncio.CancelledError:
for task in futures:
task.cancel()

return done
return await asyncio.gather(*futures, return_exceptions=True)

async def _run_ftp_download(self, main_pb):
futures = await self._run_from_queue(
self.ftp_queue.generate_queue(),
self._generate_tokens(),
main_pb,
)
# Wait for all the coroutines to finish
done, _ = await asyncio.wait(futures)

return done
try:
# Wait for all the coroutines to finish
done, _ = await asyncio.wait(futures)
except asyncio.CancelledError:
for task in futures:
task.cancel()

return await asyncio.gather(*futures, return_exceptions=True)

async def _run_from_queue(self, queue, tokens, main_pb, *, session=None):
futures = []
@@ -405,10 +442,13 @@ async def _run_from_queue(self, queue, tokens, main_pb, *, session=None):
future = asyncio.create_task(get_file(session, token=token, file_pb=file_pb))

def callback(token, future, main_pb):
tokens.put_nowait(token)
# Update the main progressbar
if main_pb and not future.exception():
main_pb.update(1)
try:
tokens.put_nowait(token)
# Update the main progressbar
if main_pb and not future.exception():
main_pb.update(1)
except asyncio.CancelledError:
return

future.add_done_callback(partial(callback, token, main_pb=main_pb))
futures.append(future)
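
The done-callback returns the worker's token to the pool and ticks the main progress bar; functools.partial pre-binds the extra arguments because add_done_callback passes only the finished future. A small sketch of that calling convention (hypothetical job):

    import asyncio
    from functools import partial

    async def job():
        return 42

    def callback(label, future, main_pb=None):
        # add_done_callback hands the finished future to the callback as the
        # last positional argument; partial binds everything else up front.
        if not future.cancelled() and not future.exception():
            print(label, future.result())

    async def main():
        task = asyncio.create_task(job())
        task.add_done_callback(partial(callback, "finished:", main_pb=None))
        await task

    asyncio.run(main())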
@@ -465,6 +505,7 @@ async def _get_http(

# Define filepath and writer here as we use them in the except block
filepath = writer = None
tasks = []
try:
scheme = urllib.parse.urlparse(url).scheme
if scheme == "http":
@@ -511,7 +552,6 @@
# as tuples: (offset, chunk)
downloaded_chunk_queue = asyncio.Queue()

download_workers = []
writer = asyncio.create_task(
self._write_worker(downloaded_chunk_queue, file_pb, filepath)
)
@@ -531,7 +571,7 @@
# let the last part download everything
ranges[-1][1] = ""
for _range in ranges:
download_workers.append(
tasks.append(
asyncio.create_task(
self._http_download_worker(
session,
@@ -544,7 +584,7 @@
)
)
else:
download_workers.append(
tasks.append(
asyncio.create_task(
self._http_download_worker(
session,
@@ -559,19 +599,26 @@
# Close the initial request here before we start transferring data.

# run all the download workers
await asyncio.gather(*download_workers)
await asyncio.gather(*tasks)
# join() waits till all the items in the queue have been processed
await downloaded_chunk_queue.join()
return str(filepath)

except Exception as e:
except (Exception, asyncio.CancelledError) as e:
for task in tasks:
task.cancel()
# We have to cancel the writer here before we try and remove the
# file so it's closed (otherwise windows gets angry).
if writer is not None:
await cancel_task(writer)
# Set writer to None so we don't cancel it twice.
writer = None
# If filepath is None then the exception occurred before the request
# computed the filepath, so we have no file to cleanup
if filepath is not None:
remove_file(filepath)
raise FailedDownload(filepath_partial, url, e)

finally:
if writer is not None:
writer.cancel()
@@ -776,9 +823,10 @@ async def _get_ftp(
await downloaded_chunks_queue.join()
return str(filepath)

except Exception as e:
except (Exception, asyncio.CancelledError) as e:
if writer is not None:
await cancel_task(writer)
writer = None
# If filepath is None then the exception occurred before the request
# computed the filepath, so we have no file to cleanup
if filepath is not None:
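
Both error paths now cancel the writer task and await it before removing the partial file; on Windows, unlinking a file that an open handle still references fails. A sketch of the ordering, with a hypothetical cancel_task helper standing in for parfive's:

    import os
    import asyncio
    import contextlib

    async def cancel_task(task):
        # Cancel the task and wait for it to finish, so any file handle
        # it holds is actually closed before we touch the file.
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await task

    async def cleanup(filepath, writer):
        # Order matters: close the writer first, then delete the file,
        # otherwise Windows refuses to remove the still-open file.
        if writer is not None:
            await cancel_task(writer)
        if filepath is not None:
            os.remove(filepath)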
3 changes: 2 additions & 1 deletion parfive/main.py
@@ -32,7 +32,8 @@ def run_parfive(args):
for err in results.errors:
err_str += f"{err.url} \t {err.exception}\n"
if err_str:
sys.exit(err_str)
print(err_str, file=sys.stderr)
sys.exit(1)

sys.exit(0)

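
The CLI change is behaviour-preserving but explicit: sys.exit() with a non-integer, non-None argument already prints it to stderr and exits with status 1, so the new code just spells out both steps (the "Explicit is better than implicit" commit). The two forms side by side:

    import sys

    def explicit(err_str):
        # What the new code does.
        print(err_str, file=sys.stderr)
        sys.exit(1)

    def implicit(err_str):
        # What sys.exit does for a string argument:
        # print it to stderr and exit with status 1.
        sys.exit(err_str)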
10 changes: 8 additions & 2 deletions parfive/utils.py
@@ -6,6 +6,7 @@
import warnings
from pathlib import Path
from itertools import count
from concurrent.futures import ThreadPoolExecutor

import aiohttp

@@ -35,15 +36,20 @@ def default_name(path: os.PathLike, resp: aiohttp.ClientResponse, url: str) -> o
return pathlib.Path(path) / name


def run_in_thread(aio_pool, loop, coro):
def run_task_in_thread(loop, coro):
"""
This function returns the asyncio Future after running the loop in a
thread.
This makes the return value of this function the same as the return
of ``loop.run_until_complete``.
"""
return aio_pool.submit(loop.run_until_complete, coro).result()
with ThreadPoolExecutor(max_workers=1) as aio_pool:
try:
future = aio_pool.submit(loop.run_until_complete, coro)
except KeyboardInterrupt:
future.cancel()
return future.result()


async def get_ftp_size(client, filepath):
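
run_task_in_thread now owns its single-worker pool, so the Downloader only hands over a fresh loop and the task; this is the path taken when download() is called while another loop is already running (e.g. in a Jupyter notebook). A minimal sketch of the calling pattern (assumed coroutine):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    async def coro():
        await asyncio.sleep(0.1)
        return "done"

    def run_task_in_thread(loop, task):
        # Run the fresh, not-yet-running loop on a one-off worker thread and
        # block until the task completes, like loop.run_until_complete would.
        with ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(loop.run_until_complete, task).result()

    # As in Downloader._run_in_loop: a new loop because the current one is busy.
    new_loop = asyncio.new_event_loop()
    task = new_loop.create_task(coro())
    print(run_task_in_thread(new_loop, task))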
