From eab9fd271bb9d323a403f94967987def22e9cb31 Mon Sep 17 00:00:00 2001
From: "Daniel Jahn (dahn)"
Date: Sat, 6 Apr 2024 10:48:13 +0200
Subject: [PATCH] Resolve Mypy errors in `v3` branch (#1692)

* refactor(v3): Using appropriate types
* fix(v3): Typing fixes + minor code fixes
* fix(v3): _sync_iter works with coroutines
* docs(v3/store/core.py): clearer comment
* fix(metadata.py): Use Any outside TYPE_CHECKING for Pydantic
* fix(zarr/v3): correct zarr format + remove unused method
* fix(v3/store/core.py): Potential suggestion on handling str store_like
* refactor(zarr/v3): Add more typing
* ci(.pre-commit-config.yaml): zarr v3 mypy checks turned on in pre-commit
---
 .pre-commit-config.yaml            |  1 -
 src/zarr/v3/abc/metadata.py        |  3 +-
 src/zarr/v3/array.py               |  2 +-
 src/zarr/v3/chunk_grids.py         |  2 +-
 src/zarr/v3/chunk_key_encodings.py |  6 ++--
 src/zarr/v3/codecs/transpose.py    |  8 +++---
 src/zarr/v3/group.py               | 44 ++++++++++++++++++------------
 src/zarr/v3/metadata.py            |  6 ++--
 src/zarr/v3/store/core.py          | 15 ++--------
 src/zarr/v3/store/local.py         |  2 +-
 src/zarr/v3/sync.py                |  8 ++----
 11 files changed, 47 insertions(+), 50 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 79344604a..10aff8b4c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -31,7 +31,6 @@ repos:
     hooks:
       - id: mypy
         files: src
-        exclude: ^src/zarr/v3
         args: []
         additional_dependencies:
           - types-redis
diff --git a/src/zarr/v3/abc/metadata.py b/src/zarr/v3/abc/metadata.py
index bdd2f86d5..4fcabf72a 100644
--- a/src/zarr/v3/abc/metadata.py
+++ b/src/zarr/v3/abc/metadata.py
@@ -5,11 +5,12 @@
     from typing import Dict
     from typing_extensions import Self

-from dataclasses import fields
+from dataclasses import fields, dataclass

 from zarr.v3.common import JSON


+@dataclass(frozen=True)
 class Metadata:
     def to_dict(self) -> JSON:
         """
diff --git a/src/zarr/v3/array.py b/src/zarr/v3/array.py
index 632f7d8ec..c0a00a624 100644
--- a/src/zarr/v3/array.py
+++ b/src/zarr/v3/array.py
@@ -182,7 +182,7 @@ def shape(self) -> ChunkCoords:

     @property
     def size(self) -> int:
-        return np.prod(self.metadata.shape)
+        return np.prod(self.metadata.shape).item()

     @property
     def dtype(self) -> np.dtype:
diff --git a/src/zarr/v3/chunk_grids.py b/src/zarr/v3/chunk_grids.py
index 6c4832379..b0a2a7bb3 100644
--- a/src/zarr/v3/chunk_grids.py
+++ b/src/zarr/v3/chunk_grids.py
@@ -20,7 +20,7 @@ class ChunkGrid(Metadata):

     @classmethod
     def from_dict(cls, data: Dict[str, JSON]) -> ChunkGrid:
         if isinstance(data, ChunkGrid):
-            return data  # type: ignore
+            return data
         name_parsed, _ = parse_named_configuration(data)
         if name_parsed == "regular":
diff --git a/src/zarr/v3/chunk_key_encodings.py b/src/zarr/v3/chunk_key_encodings.py
index e4339240e..9889a2f04 100644
--- a/src/zarr/v3/chunk_key_encodings.py
+++ b/src/zarr/v3/chunk_key_encodings.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from abc import abstractmethod
-from typing import TYPE_CHECKING, Dict, Literal
+from typing import TYPE_CHECKING, Dict, Literal, cast
 from dataclasses import dataclass

 from zarr.v3.abc.metadata import Metadata
@@ -19,7 +19,7 @@
 def parse_separator(data: JSON) -> SeparatorLiteral:
     if data not in (".", "/"):
         raise ValueError(f"Expected an '.' or '/' separator. Got {data} instead.")
-    return data  # type: ignore
+    return cast(SeparatorLiteral, data)


 @dataclass(frozen=True)
@@ -35,7 +35,7 @@ def __init__(self, *, separator: SeparatorLiteral) -> None:

     @classmethod
     def from_dict(cls, data: Dict[str, JSON]) -> ChunkKeyEncoding:
         if isinstance(data, ChunkKeyEncoding):
-            return data  # type: ignore
+            return data
         name_parsed, configuration_parsed = parse_named_configuration(data)
         if name_parsed == "default":
diff --git a/src/zarr/v3/codecs/transpose.py b/src/zarr/v3/codecs/transpose.py
index f214d1e7f..b663230e3 100644
--- a/src/zarr/v3/codecs/transpose.py
+++ b/src/zarr/v3/codecs/transpose.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, Dict, Iterable
+from typing import TYPE_CHECKING, Dict, Iterable, Union, cast

 from dataclasses import dataclass, replace

@@ -16,12 +16,12 @@
 from zarr.v3.codecs.registry import register_codec


-def parse_transpose_order(data: JSON) -> Tuple[int]:
+def parse_transpose_order(data: Union[JSON, Iterable[int]]) -> Tuple[int, ...]:
     if not isinstance(data, Iterable):
         raise TypeError(f"Expected an iterable. Got {data} instead.")
     if not all(isinstance(a, int) for a in data):
         raise TypeError(f"Expected an iterable of integers. Got {data} instead.")
-    return tuple(data)  # type: ignore[return-value]
+    return tuple(cast(Iterable[int], data))


 @dataclass(frozen=True)
@@ -31,7 +31,7 @@ class TransposeCodec(ArrayArrayCodec):
     order: Tuple[int, ...]

     def __init__(self, *, order: ChunkCoordsLike) -> None:
-        order_parsed = parse_transpose_order(order)  # type: ignore[arg-type]
+        order_parsed = parse_transpose_order(order)

         object.__setattr__(self, "order", order_parsed)

diff --git a/src/zarr/v3/group.py b/src/zarr/v3/group.py
index acd5ca0d6..0012a77a8 100644
--- a/src/zarr/v3/group.py
+++ b/src/zarr/v3/group.py
@@ -4,7 +4,7 @@
 import asyncio
 import json
 import logging
-from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, Iterator, List
+from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, List

 from zarr.v3.abc.metadata import Metadata
 from zarr.v3.array import AsyncArray, Array
@@ -46,11 +46,11 @@ def to_bytes(self) -> Dict[str, bytes]:
             return {ZARR_JSON: json.dumps(self.to_dict()).encode()}
         else:
             return {
-                ZGROUP_JSON: self.zarr_format,
+                ZGROUP_JSON: json.dumps({"zarr_format": 2}).encode(),
                 ZATTRS_JSON: json.dumps(self.attributes).encode(),
             }

-    def __init__(self, attributes: Dict[str, Any] = None, zarr_format: Literal[2, 3] = 3):
+    def __init__(self, attributes: Optional[Dict[str, Any]] = None, zarr_format: Literal[2, 3] = 3):
         attributes_parsed = parse_attributes(attributes)
         zarr_format_parsed = parse_zarr_format(zarr_format)

@@ -104,7 +104,7 @@ async def open(
         zarr_format: Literal[2, 3] = 3,
     ) -> AsyncGroup:
         store_path = make_store_path(store)
-        zarr_json_bytes = await (store_path / ZARR_JSON).get_async()
+        zarr_json_bytes = await (store_path / ZARR_JSON).get()
         assert zarr_json_bytes is not None

         # TODO: consider trying to autodiscover the zarr-format here
@@ -139,7 +139,7 @@ def from_dict(
         store_path: StorePath,
         data: Dict[str, Any],
         runtime_configuration: RuntimeConfiguration,
-    ) -> Group:
+    ) -> AsyncGroup:
         group = cls(
             metadata=GroupMetadata.from_dict(data),
             store_path=store_path,
@@ -168,10 +168,12 @@ async def getitem(
             zarr_json = json.loads(zarr_json_bytes)
             if zarr_json["node_type"] == "group":
                 return type(self).from_dict(store_path, zarr_json, self.runtime_configuration)
-            if zarr_json["node_type"] == "array":
+            elif zarr_json["node_type"] == "array":
                 return AsyncArray.from_dict(
                     store_path, zarr_json, runtime_configuration=self.runtime_configuration
                 )
+            else:
+                raise ValueError(f"unexpected node_type: {zarr_json['node_type']}")
         elif self.metadata.zarr_format == 2:
             # Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs?
             # This guarantees that we will always make at least one extra request to the store
@@ -271,7 +273,7 @@ def __repr__(self):
     async def nchildren(self) -> int:
         raise NotImplementedError

-    async def children(self) -> AsyncIterator[AsyncArray, AsyncGroup]:
+    async def children(self) -> AsyncIterator[Union[AsyncArray, AsyncGroup]]:
         raise NotImplementedError

     async def contains(self, child: str) -> bool:
@@ -381,8 +383,12 @@ async def update_attributes_async(self, new_attributes: Dict[str, Any]) -> Group
         new_metadata = replace(self.metadata, attributes=new_attributes)

         # Write new metadata
-        await (self.store_path / ZARR_JSON).set_async(new_metadata.to_bytes())
-        return replace(self, metadata=new_metadata)
+        to_save = new_metadata.to_bytes()
+        awaitables = [(self.store_path / key).set(value) for key, value in to_save.items()]
+        await asyncio.gather(*awaitables)
+
+        async_group = replace(self._async_group, metadata=new_metadata)
+        return replace(self, _async_group=async_group)

     @property
     def metadata(self) -> GroupMetadata:
@@ -396,34 +402,38 @@ def attrs(self) -> Attributes:
     def info(self):
         return self._async_group.info

+    @property
+    def store_path(self) -> StorePath:
+        return self._async_group.store_path
+
     def update_attributes(self, new_attributes: Dict[str, Any]):
         self._sync(self._async_group.update_attributes(new_attributes))
         return self

     @property
     def nchildren(self) -> int:
-        return self._sync(self._async_group.nchildren)
+        return self._sync(self._async_group.nchildren())

     @property
-    def children(self) -> List[Array, Group]:
-        _children = self._sync_iter(self._async_group.children)
+    def children(self) -> List[Union[Array, Group]]:
+        _children = self._sync_iter(self._async_group.children())
         return [Array(obj) if isinstance(obj, AsyncArray) else Group(obj) for obj in _children]

     def __contains__(self, child) -> bool:
         return self._sync(self._async_group.contains(child))

-    def group_keys(self) -> Iterator[str]:
-        return self._sync_iter(self._async_group.group_keys)
+    def group_keys(self) -> List[str]:
+        return self._sync_iter(self._async_group.group_keys())

     def groups(self) -> List[Group]:
         # TODO: in v2 this was a generator that return key: Group
-        return [Group(obj) for obj in self._sync_iter(self._async_group.groups)]
+        return [Group(obj) for obj in self._sync_iter(self._async_group.groups())]

     def array_keys(self) -> List[str]:
-        return self._sync_iter(self._async_group.array_keys)
+        return self._sync_iter(self._async_group.array_keys())

     def arrays(self) -> List[Array]:
-        return [Array(obj) for obj in self._sync_iter(self._async_group.arrays)]
+        return [Array(obj) for obj in self._sync_iter(self._async_group.arrays())]

     def tree(self, expand=False, level=None) -> Any:
         return self._sync(self._async_group.tree(expand=expand, level=level))
diff --git a/src/zarr/v3/metadata.py b/src/zarr/v3/metadata.py
index de3055abd..a5e892731 100644
--- a/src/zarr/v3/metadata.py
+++ b/src/zarr/v3/metadata.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from enum import Enum
-from typing import TYPE_CHECKING, cast, Dict, Iterable
+from typing import TYPE_CHECKING, cast, Dict, Iterable, Any
 from dataclasses import dataclass, field
 import json
 import numpy as np
@@ -10,7 +10,7 @@


 if TYPE_CHECKING:
-    from typing import Any, Literal, Union, List, Optional, Tuple
+    from typing import Literal, Union, List, Optional, Tuple
     from zarr.v3.codecs.pipeline import CodecPipeline


@@ -244,7 +244,7 @@ class ArrayV2Metadata(Metadata):
     filters: Optional[List[Dict[str, Any]]] = None
     dimension_separator: Literal[".", "/"] = "."
     compressor: Optional[Dict[str, Any]] = None
-    attributes: Optional[Dict[str, Any]] = field(default_factory=dict)
+    attributes: Optional[Dict[str, Any]] = cast(Dict[str, Any], field(default_factory=dict))
     zarr_format: Literal[2] = field(init=False, default=2)

     def __init__(
diff --git a/src/zarr/v3/store/core.py b/src/zarr/v3/store/core.py
index 0ef1c8569..16714d9e3 100644
--- a/src/zarr/v3/store/core.py
+++ b/src/zarr/v3/store/core.py
@@ -5,6 +5,7 @@

 from zarr.v3.common import BytesLike
 from zarr.v3.abc.store import Store
+from zarr.v3.store.local import LocalStore


 def _dereference_path(root: str, path: str) -> str:
@@ -24,10 +25,6 @@ def __init__(self, store: Store, path: Optional[str] = None):
         self.store = store
         self.path = path or ""

-    @classmethod
-    def from_path(cls, pth: Path) -> StorePath:
-        return cls(Store.from_path(pth))
-
     async def get(
         self, byte_range: Optional[Tuple[int, Optional[int]]] = None
     ) -> Optional[BytesLike]:
@@ -70,14 +67,6 @@ def make_store_path(store_like: StoreLike) -> StorePath:
         return store_like
     elif isinstance(store_like, Store):
         return StorePath(store_like)
-    # elif isinstance(store_like, Path):
-    #     return StorePath(Store.from_path(store_like))
     elif isinstance(store_like, str):
-        try:
-            from upath import UPath
-
-            return StorePath(Store.from_path(UPath(store_like)))
-        except ImportError as e:
-            raise e
-        # return StorePath(LocalStore(Path(store_like)))
+        return StorePath(LocalStore(Path(store_like)))
     raise TypeError
diff --git a/src/zarr/v3/store/local.py b/src/zarr/v3/store/local.py
index a62eea20f..c3da11045 100644
--- a/src/zarr/v3/store/local.py
+++ b/src/zarr/v3/store/local.py
@@ -146,7 +146,7 @@ async def list_prefix(self, prefix: str) -> List[str]:
         """

         def _list_prefix(root: Path, prefix: str) -> List[str]:
-            files = [p for p in (root / prefix).rglob("*") if p.is_file()]
+            files = [str(p) for p in (root / prefix).rglob("*") if p.is_file()]
             return files

         return await to_thread(_list_prefix, self.root, prefix)
diff --git a/src/zarr/v3/sync.py b/src/zarr/v3/sync.py
index f0996c019..fcc8e7b27 100644
--- a/src/zarr/v3/sync.py
+++ b/src/zarr/v3/sync.py
@@ -5,7 +5,6 @@
 from typing import (
     Any,
     AsyncIterator,
-    Callable,
     Coroutine,
     List,
     Optional,
@@ -112,11 +111,10 @@ def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T:
         # this should allow us to better type the sync wrapper
         return sync(coroutine, loop=self._sync_configuration.asyncio_loop)

-    def _sync_iter(
-        self, func: Callable[P, AsyncIterator[T]], *args: P.args, **kwargs: P.kwargs
-    ) -> List[T]:
+    def _sync_iter(self, coroutine: Coroutine[Any, Any, AsyncIterator[T]]) -> List[T]:
         async def iter_to_list() -> List[T]:
             # TODO: replace with generators so we don't materialize the entire iterator at once
-            return [item async for item in func(*args, **kwargs)]
+            async_iterator = await coroutine
+            return [item async for item in async_iterator]

         return self._sync(iter_to_list())
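
Illustration (not part of the patch): the final hunk changes `_sync_iter` from taking a callable plus arguments to taking an already-created coroutine that resolves to an async iterator, which is why callers such as `Group.children` now pass `self._async_group.children()` rather than `self._async_group.children`. The sketch below shows that calling pattern in isolation; the `Example` class, its `_values` generator, and the `asyncio.run`-based `_sync` are hypothetical stand-ins for the mixin in src/zarr/v3/sync.py, not code from the patch.

import asyncio
from typing import Any, AsyncIterator, Coroutine, List, TypeVar

T = TypeVar("T")


class Example:
    def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T:
        # Simplified stand-in: run the coroutine to completion on a fresh event loop.
        # The real helper dispatches to a shared loop via zarr.v3.sync.sync().
        return asyncio.run(coroutine)

    def _sync_iter(self, coroutine: Coroutine[Any, Any, AsyncIterator[T]]) -> List[T]:
        async def iter_to_list() -> List[T]:
            # Await the coroutine to obtain the async iterator, then drain it into a list.
            async_iterator = await coroutine
            return [item async for item in async_iterator]

        return self._sync(iter_to_list())

    async def _values(self) -> AsyncIterator[int]:
        # Hypothetical async generator standing in for the real child-listing logic.
        for i in range(3):
            yield i

    async def children(self) -> AsyncIterator[int]:
        # Mirrors the AsyncGroup.children signature: a coroutine whose result is an async iterator.
        return self._values()


example = Example()
print(example._sync_iter(example.children()))  # prints [0, 1, 2]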