Skip to content

Commit

Permalink
Fix open kwarg passthrough (#204)
Browse files Browse the repository at this point in the history
* upath.core: handle kwargs in UPath.open
* upath: tests for open
* upath.local: support fsspec options in open
  • Loading branch information
ap-- committed Mar 3, 2024
1 parent 3646cd2 commit 041aca1
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .flake8
Expand Up @@ -9,6 +9,8 @@ ignore=
# unindexed parameters in the str.format, see:
# https://pypi.org/project/flake8-string-format/
P1
# def statements on the same line with overload
E704
max_line_length = 88
max-complexity = 15
select = B,C,E,F,W,T4,B902,T,P
Expand Down
69 changes: 65 additions & 4 deletions upath/core.py
Expand Up @@ -6,14 +6,19 @@
from copy import copy
from pathlib import Path
from types import MappingProxyType
from typing import IO
from typing import TYPE_CHECKING
from typing import Any
from typing import BinaryIO
from typing import Literal
from typing import Mapping
from typing import TextIO
from typing import TypeVar
from typing import overload
from urllib.parse import urlsplit

from fsspec import AbstractFileSystem
from fsspec import get_filesystem_class
from fsspec.registry import get_filesystem_class
from fsspec.spec import AbstractFileSystem

from upath._compat import FSSpecAccessorShim
from upath._compat import PathlibPathShim
Expand Down Expand Up @@ -741,8 +746,64 @@ def is_socket(self):
def samefile(self, other_path):
raise NotImplementedError

def open(self, mode="r", buffering=-1, encoding=None, errors=None, newline=None):
return self.fs.open(self.path, mode) # fixme
@overload
def open(
self,
mode: Literal["r", "w", "a"] = ...,
buffering: int = ...,
encoding: str = ...,
errors: str = ...,
newline: str = ...,
**fsspec_kwargs: Any,
) -> TextIO: ...

@overload
def open(
self,
mode: Literal["rb", "wb", "ab"] = ...,
buffering: int = ...,
encoding: str = ...,
errors: str = ...,
newline: str = ...,
**fsspec_kwargs: Any,
) -> BinaryIO: ...

def open(
self,
mode: str = "r",
*args: Any,
**fsspec_kwargs: Any,
) -> IO[Any]:
"""
Open the file pointed by this path and return a file object, as
the built-in open() function does.
Parameters
----------
mode:
Opening mode. Default is 'r'.
buffering:
Default is the block size of the underlying fsspec filesystem.
encoding:
Encoding is only used in text mode. Default is None.
errors:
Error handling for encoding. Only used in text mode. Default is None.
newline:
Newline handling. Only used in text mode. Default is None.
**fsspec_kwargs:
Additional options for the fsspec filesystem.
"""
# match the signature of pathlib.Path.open()
for key, value in zip(["buffering", "encoding", "errors", "newline"], args):
if key in fsspec_kwargs:
raise TypeError(
f"{type(self).__name__}.open() got multiple values for '{key}'"
)
fsspec_kwargs[key] = value
# translate pathlib buffering to fs block_size
if "buffering" in fsspec_kwargs:
fsspec_kwargs.setdefault("block_size", fsspec_kwargs.pop("buffering"))
return self.fs.open(self.path, mode=mode, **fsspec_kwargs)

def iterdir(self):
for name in self.fs.listdir(self.path):
Expand Down
43 changes: 43 additions & 0 deletions upath/implementations/local.py
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from pathlib import PosixPath
from pathlib import WindowsPath
from typing import IO
from typing import Any
from typing import Collection
from typing import MutableMapping
Expand Down Expand Up @@ -110,6 +111,27 @@ class PosixUPath(PosixPath, LocalPath):
# assign all PosixPath methods/attrs to prevent multi inheritance issues
_set_class_attributes(locals(), src=PosixPath)

def open(
self,
mode="r",
buffering=-1,
encoding=None,
errors=None,
newline=None,
**fsspec_kwargs,
) -> IO[Any]:
if fsspec_kwargs:
return super(LocalPath, self).open(
mode=mode,
buffering=buffering,
encoding=encoding,
errors=errors,
newline=newline,
**fsspec_kwargs,
)
else:
return PosixPath.open(self, mode, buffering, encoding, errors, newline)

if sys.version_info < (3, 12):

def __new__(
Expand Down Expand Up @@ -153,6 +175,27 @@ class WindowsUPath(WindowsPath, LocalPath):
# assign all WindowsPath methods/attrs to prevent multi inheritance issues
_set_class_attributes(locals(), src=WindowsPath)

def open(
self,
mode="r",
buffering=-1,
encoding=None,
errors=None,
newline=None,
**fsspec_kwargs,
) -> IO[Any]:
if fsspec_kwargs:
return super(LocalPath, self).open(
mode=mode,
buffering=buffering,
encoding=encoding,
errors=errors,
newline=newline,
**fsspec_kwargs,
)
else:
return WindowsPath.open(self, mode, buffering, encoding, errors, newline)

if sys.version_info < (3, 12):

def __new__(
Expand Down
20 changes: 19 additions & 1 deletion upath/tests/cases.py
Expand Up @@ -236,7 +236,25 @@ def test_makedirs_exist_ok_false(self):
new_dir._accessor.makedirs(new_dir, exist_ok=False)

def test_open(self):
pass
p = self.path.joinpath("file1.txt")
with p.open(mode="r") as f:
assert f.read() == "hello world"
with p.open(mode="rb") as f:
assert f.read() == b"hello world"

def test_open_buffering(self):
p = self.path.joinpath("file1.txt")
p.open(buffering=-1)

def test_open_block_size(self):
p = self.path.joinpath("file1.txt")
with p.open(mode="r", block_size=8192) as f:
assert f.read() == "hello world"

def test_open_errors(self):
p = self.path.joinpath("file1.txt")
with p.open(mode="r", encoding="ascii", errors="strict") as f:
assert f.read() == "hello world"

def test_owner(self):
with pytest.raises(NotImplementedError):
Expand Down
20 changes: 20 additions & 0 deletions upath/tests/implementations/test_data.py
Expand Up @@ -92,6 +92,26 @@ def test_mkdir_parents_true_exists_ok_true(self):
def test_mkdir_parents_true_exists_ok_false(self):
pass

def test_open(self):
p = UPath("data:text/plain;base64,aGVsbG8gd29ybGQ=")
with p.open(mode="r") as f:
assert f.read() == "hello world"
with p.open(mode="rb") as f:
assert f.read() == b"hello world"

def test_open_buffering(self):
self.path.open(buffering=-1)

def test_open_block_size(self):
p = UPath("data:text/plain;base64,aGVsbG8gd29ybGQ=")
with p.open(mode="r", block_size=8192) as f:
assert f.read() == "hello world"

def test_open_errors(self):
p = UPath("data:text/plain;base64,aGVsbG8gd29ybGQ=")
with p.open(mode="r", encoding="ascii", errors="strict") as f:
assert f.read() == "hello world"

def test_read_bytes(self, pathlib_base):
assert len(self.path.read_bytes()) == 69

Expand Down

0 comments on commit 041aca1

Please sign in to comment.