diff --git a/jupyter_archive/handlers.py b/jupyter_archive/handlers.py index 35e3064..4401959 100644 --- a/jupyter_archive/handlers.py +++ b/jupyter_archive/handlers.py @@ -72,15 +72,15 @@ def make_writer(handler, archive_format="zip"): def make_reader(archive_path): - archive_format = "".join(archive_path.suffixes)[1:] + archive_format = "".join(archive_path.suffixes) - if archive_format.endswith("zip"): + if archive_format.endswith(".zip"): archive_file = zipfile.ZipFile(archive_path, mode="r") - elif any([archive_format.endswith(ext) for ext in ["tgz", "tar.gz"]]): + elif any([archive_format.endswith(ext) for ext in [".tgz", ".tar.gz"]]): archive_file = tarfile.open(archive_path, mode="r|gz") - elif any([archive_format.endswith(ext) for ext in ["tbz", "tbz2", "tar.bz", "tar.bz2"]]): + elif any([archive_format.endswith(ext) for ext in [".tbz", ".tbz2", ".tar.bz", ".tar.bz2"]]): archive_file = tarfile.open(archive_path, mode="r|bz2") - elif any([archive_format.endswith(ext) for ext in ["txz", "tar.xz"]]): + elif any([archive_format.endswith(ext) for ext in [".txz", ".tar.xz"]]): archive_file = tarfile.open(archive_path, mode="r|xz") else: raise ValueError("'{}' is not a valid archive format.".format(archive_format)) @@ -141,8 +141,7 @@ async def get(self, archive_path, include_body=False): raise web.HTTPError(400) archive_path = pathlib.Path(cm.root_dir) / url2path(archive_path) - archive_name = archive_path.name - archive_filename = archive_path.with_suffix(".{}".format(archive_format)).name + archive_filename = f"{archive_path.name}.{archive_format}" self.log.info("Prepare {} for archiving and downloading.".format(archive_filename)) self.set_header("content-type", "application/octet-stream") diff --git a/jupyter_archive/tests/test_archive_handler.py b/jupyter_archive/tests/test_archive_handler.py index 4eabf5b..77dac82 100644 --- a/jupyter_archive/tests/test_archive_handler.py +++ b/jupyter_archive/tests/test_archive_handler.py @@ -5,6 +5,8 @@ import pytest +from tornado.httpclient import HTTPClientError + @pytest.mark.parametrize( "followSymlinks, download_hidden, file_list", @@ -111,6 +113,36 @@ async def test_download(jp_fetch, jp_root_dir, followSymlinks, download_hidden, assert set(map(lambda m: m.name, tf.getmembers())) == file_list +def _create_archive_file(root_dir, file_name, format, mode): + # Create a dummy directory. + archive_dir_path = root_dir / file_name + archive_dir_path.mkdir(parents=True) + + (archive_dir_path / "extract-test1.txt").write_text("hello1") + (archive_dir_path / "extract-test2.txt").write_text("hello2") + (archive_dir_path / "extract-test3.md").write_text("hello3") + + # Make an archive + archive_dir_path = root_dir / file_name + # The request should fail when the extension has an unnecessary prefix. + archive_path = archive_dir_path.parent / f"{archive_dir_path.name}.{format}" + if format == "zip": + with zipfile.ZipFile(archive_path, mode=mode) as writer: + for file_path in archive_dir_path.rglob("*"): + if file_path.is_file(): + writer.write(file_path, file_path.relative_to(root_dir)) + else: + with tarfile.open(str(archive_path), mode=mode) as writer: + for file_path in archive_dir_path.rglob("*"): + if file_path.is_file(): + writer.add(file_path, file_path.relative_to(root_dir)) + + # Remove the directory + shutil.rmtree(archive_dir_path) + + return archive_dir_path, archive_path + + @pytest.mark.parametrize( "file_name", [ @@ -134,30 +166,7 @@ async def test_download(jp_fetch, jp_root_dir, followSymlinks, download_hidden, ], ) async def test_extract(jp_fetch, jp_root_dir, file_name, format, mode): - # Create a dummy directory. - archive_dir_path = jp_root_dir / file_name - archive_dir_path.mkdir(parents=True) - - (archive_dir_path / "extract-test1.txt").write_text("hello1") - (archive_dir_path / "extract-test2.txt").write_text("hello2") - (archive_dir_path / "extract-test3.md").write_text("hello3") - - # Make an archive - archive_dir_path = jp_root_dir / file_name - archive_path = archive_dir_path.with_suffix("." + format) - if format == "zip": - with zipfile.ZipFile(archive_path, mode=mode) as writer: - for file_path in archive_dir_path.rglob("*"): - if file_path.is_file(): - writer.write(file_path, file_path.relative_to(jp_root_dir)) - else: - with tarfile.open(str(archive_path), mode=mode) as writer: - for file_path in archive_dir_path.rglob("*"): - if file_path.is_file(): - writer.add(file_path, file_path.relative_to(jp_root_dir)) - - # Remove the directory - shutil.rmtree(archive_dir_path) + archive_dir_path, archive_path = _create_archive_file(jp_root_dir, file_name, format, mode) r = await jp_fetch("extract-archive", archive_path.relative_to(jp_root_dir).as_posix(), method="GET") assert r.code == 200 @@ -165,3 +174,28 @@ async def test_extract(jp_fetch, jp_root_dir, file_name, format, mode): n_files = len(list(archive_dir_path.glob("*"))) assert n_files == 3 + + +@pytest.mark.parametrize( + "format, mode", + [ + ("zip", "w"), + ("tgz", "w|gz"), + ("tar.gz", "w|gz"), + ("tbz", "w|bz2"), + ("tbz2", "w|bz2"), + ("tar.bz", "w|bz2"), + ("tar.bz2", "w|bz2"), + ("txz", "w|xz"), + ("tar.xz", "w|xz"), + ], +) +async def test_extract_failure(jp_fetch, jp_root_dir, format, mode): + # The request should fail when the extension has an unnecessary prefix. + prefixed_format = f"prefix{format}" + archive_dir_path, archive_path = _create_archive_file(jp_root_dir, "extract-archive-dir", prefixed_format, mode) + + with pytest.raises(Exception) as e: + await jp_fetch("extract-archive", archive_path.relative_to(jp_root_dir).as_posix(), method="GET") + assert e.type == HTTPClientError + assert not archive_dir_path.exists()