Skip to content

Commit

Permalink
Resolve Mypy errors in v3 branch (zarr-developers#1692)
Browse files Browse the repository at this point in the history
* refactor(v3): Using appropriate types

* fix(v3): Typing fixes + minor code fixes

* fix(v3): _sync_iter works with coroutines

* docs(v3/store/core.py): clearer comment

* fix(metadata.py): Use Any outside TYPE_CHECKING for Pydantic

* fix(zarr/v3): correct zarr format + remove unused method

* fix(v3/store/core.py): Potential suggestion on handling str store_like

* refactor(zarr/v3): Add more typing

* ci(.pre-commit-config.yaml): zarr v3 mypy checks turned on in pre-commit
  • Loading branch information
DahnJ authored and d-v-b committed Apr 10, 2024
1 parent 4b20501 commit eab9fd2
Show file tree
Hide file tree
Showing 11 changed files with 47 additions and 50 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Expand Up @@ -31,7 +31,6 @@ repos:
hooks:
- id: mypy
files: src
exclude: ^src/zarr/v3
args: []
additional_dependencies:
- types-redis
Expand Down
3 changes: 2 additions & 1 deletion src/zarr/v3/abc/metadata.py
Expand Up @@ -5,11 +5,12 @@
from typing import Dict
from typing_extensions import Self

from dataclasses import fields
from dataclasses import fields, dataclass

from zarr.v3.common import JSON


@dataclass(frozen=True)
class Metadata:
def to_dict(self) -> JSON:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/v3/array.py
Expand Up @@ -182,7 +182,7 @@ def shape(self) -> ChunkCoords:

@property
def size(self) -> int:
return np.prod(self.metadata.shape)
return np.prod(self.metadata.shape).item()

@property
def dtype(self) -> np.dtype:
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/v3/chunk_grids.py
Expand Up @@ -20,7 +20,7 @@ class ChunkGrid(Metadata):
@classmethod
def from_dict(cls, data: Dict[str, JSON]) -> ChunkGrid:
if isinstance(data, ChunkGrid):
return data # type: ignore
return data

name_parsed, _ = parse_named_configuration(data)
if name_parsed == "regular":
Expand Down
6 changes: 3 additions & 3 deletions src/zarr/v3/chunk_key_encodings.py
@@ -1,6 +1,6 @@
from __future__ import annotations
from abc import abstractmethod
from typing import TYPE_CHECKING, Dict, Literal
from typing import TYPE_CHECKING, Dict, Literal, cast
from dataclasses import dataclass
from zarr.v3.abc.metadata import Metadata

Expand All @@ -19,7 +19,7 @@
def parse_separator(data: JSON) -> SeparatorLiteral:
if data not in (".", "/"):
raise ValueError(f"Expected an '.' or '/' separator. Got {data} instead.")
return data # type: ignore
return cast(SeparatorLiteral, data)


@dataclass(frozen=True)
Expand All @@ -35,7 +35,7 @@ def __init__(self, *, separator: SeparatorLiteral) -> None:
@classmethod
def from_dict(cls, data: Dict[str, JSON]) -> ChunkKeyEncoding:
if isinstance(data, ChunkKeyEncoding):
return data # type: ignore
return data

name_parsed, configuration_parsed = parse_named_configuration(data)
if name_parsed == "default":
Expand Down
8 changes: 4 additions & 4 deletions src/zarr/v3/codecs/transpose.py
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Dict, Iterable
from typing import TYPE_CHECKING, Dict, Iterable, Union, cast

from dataclasses import dataclass, replace

Expand All @@ -16,12 +16,12 @@
from zarr.v3.codecs.registry import register_codec


def parse_transpose_order(data: JSON) -> Tuple[int]:
def parse_transpose_order(data: Union[JSON, Iterable[int]]) -> Tuple[int, ...]:
if not isinstance(data, Iterable):
raise TypeError(f"Expected an iterable. Got {data} instead.")
if not all(isinstance(a, int) for a in data):
raise TypeError(f"Expected an iterable of integers. Got {data} instead.")
return tuple(data) # type: ignore[return-value]
return tuple(cast(Iterable[int], data))


@dataclass(frozen=True)
Expand All @@ -31,7 +31,7 @@ class TransposeCodec(ArrayArrayCodec):
order: Tuple[int, ...]

def __init__(self, *, order: ChunkCoordsLike) -> None:
order_parsed = parse_transpose_order(order) # type: ignore[arg-type]
order_parsed = parse_transpose_order(order)

object.__setattr__(self, "order", order_parsed)

Expand Down
44 changes: 27 additions & 17 deletions src/zarr/v3/group.py
Expand Up @@ -4,7 +4,7 @@
import asyncio
import json
import logging
from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, Iterator, List
from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, List
from zarr.v3.abc.metadata import Metadata

from zarr.v3.array import AsyncArray, Array
Expand Down Expand Up @@ -46,11 +46,11 @@ def to_bytes(self) -> Dict[str, bytes]:
return {ZARR_JSON: json.dumps(self.to_dict()).encode()}
else:
return {
ZGROUP_JSON: self.zarr_format,
ZGROUP_JSON: json.dumps({"zarr_format": 2}).encode(),
ZATTRS_JSON: json.dumps(self.attributes).encode(),
}

def __init__(self, attributes: Dict[str, Any] = None, zarr_format: Literal[2, 3] = 3):
def __init__(self, attributes: Optional[Dict[str, Any]] = None, zarr_format: Literal[2, 3] = 3):
attributes_parsed = parse_attributes(attributes)
zarr_format_parsed = parse_zarr_format(zarr_format)

Expand Down Expand Up @@ -104,7 +104,7 @@ async def open(
zarr_format: Literal[2, 3] = 3,
) -> AsyncGroup:
store_path = make_store_path(store)
zarr_json_bytes = await (store_path / ZARR_JSON).get_async()
zarr_json_bytes = await (store_path / ZARR_JSON).get()
assert zarr_json_bytes is not None

# TODO: consider trying to autodiscover the zarr-format here
Expand Down Expand Up @@ -139,7 +139,7 @@ def from_dict(
store_path: StorePath,
data: Dict[str, Any],
runtime_configuration: RuntimeConfiguration,
) -> Group:
) -> AsyncGroup:
group = cls(
metadata=GroupMetadata.from_dict(data),
store_path=store_path,
Expand Down Expand Up @@ -168,10 +168,12 @@ async def getitem(
zarr_json = json.loads(zarr_json_bytes)
if zarr_json["node_type"] == "group":
return type(self).from_dict(store_path, zarr_json, self.runtime_configuration)
if zarr_json["node_type"] == "array":
elif zarr_json["node_type"] == "array":
return AsyncArray.from_dict(
store_path, zarr_json, runtime_configuration=self.runtime_configuration
)
else:
raise ValueError(f"unexpected node_type: {zarr_json['node_type']}")
elif self.metadata.zarr_format == 2:
# Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs?
# This guarantees that we will always make at least one extra request to the store
Expand Down Expand Up @@ -271,7 +273,7 @@ def __repr__(self):
async def nchildren(self) -> int:
raise NotImplementedError

async def children(self) -> AsyncIterator[AsyncArray, AsyncGroup]:
async def children(self) -> AsyncIterator[Union[AsyncArray, AsyncGroup]]:
raise NotImplementedError

async def contains(self, child: str) -> bool:
Expand Down Expand Up @@ -381,8 +383,12 @@ async def update_attributes_async(self, new_attributes: Dict[str, Any]) -> Group
new_metadata = replace(self.metadata, attributes=new_attributes)

# Write new metadata
await (self.store_path / ZARR_JSON).set_async(new_metadata.to_bytes())
return replace(self, metadata=new_metadata)
to_save = new_metadata.to_bytes()
awaitables = [(self.store_path / key).set(value) for key, value in to_save.items()]
await asyncio.gather(*awaitables)

async_group = replace(self._async_group, metadata=new_metadata)
return replace(self, _async_group=async_group)

@property
def metadata(self) -> GroupMetadata:
Expand All @@ -396,34 +402,38 @@ def attrs(self) -> Attributes:
def info(self):
return self._async_group.info

@property
def store_path(self) -> StorePath:
return self._async_group.store_path

def update_attributes(self, new_attributes: Dict[str, Any]):
self._sync(self._async_group.update_attributes(new_attributes))
return self

@property
def nchildren(self) -> int:
return self._sync(self._async_group.nchildren)
return self._sync(self._async_group.nchildren())

@property
def children(self) -> List[Array, Group]:
_children = self._sync_iter(self._async_group.children)
def children(self) -> List[Union[Array, Group]]:
_children = self._sync_iter(self._async_group.children())
return [Array(obj) if isinstance(obj, AsyncArray) else Group(obj) for obj in _children]

def __contains__(self, child) -> bool:
return self._sync(self._async_group.contains(child))

def group_keys(self) -> Iterator[str]:
return self._sync_iter(self._async_group.group_keys)
def group_keys(self) -> List[str]:
return self._sync_iter(self._async_group.group_keys())

def groups(self) -> List[Group]:
# TODO: in v2 this was a generator that return key: Group
return [Group(obj) for obj in self._sync_iter(self._async_group.groups)]
return [Group(obj) for obj in self._sync_iter(self._async_group.groups())]

def array_keys(self) -> List[str]:
return self._sync_iter(self._async_group.array_keys)
return self._sync_iter(self._async_group.array_keys())

def arrays(self) -> List[Array]:
return [Array(obj) for obj in self._sync_iter(self._async_group.arrays)]
return [Array(obj) for obj in self._sync_iter(self._async_group.arrays())]

def tree(self, expand=False, level=None) -> Any:
return self._sync(self._async_group.tree(expand=expand, level=level))
Expand Down
6 changes: 3 additions & 3 deletions src/zarr/v3/metadata.py
@@ -1,6 +1,6 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, cast, Dict, Iterable
from typing import TYPE_CHECKING, cast, Dict, Iterable, Any
from dataclasses import dataclass, field
import json
import numpy as np
Expand All @@ -10,7 +10,7 @@


if TYPE_CHECKING:
from typing import Any, Literal, Union, List, Optional, Tuple
from typing import Literal, Union, List, Optional, Tuple
from zarr.v3.codecs.pipeline import CodecPipeline


Expand Down Expand Up @@ -244,7 +244,7 @@ class ArrayV2Metadata(Metadata):
filters: Optional[List[Dict[str, Any]]] = None
dimension_separator: Literal[".", "/"] = "."
compressor: Optional[Dict[str, Any]] = None
attributes: Optional[Dict[str, Any]] = field(default_factory=dict)
attributes: Optional[Dict[str, Any]] = cast(Dict[str, Any], field(default_factory=dict))
zarr_format: Literal[2] = field(init=False, default=2)

def __init__(
Expand Down
15 changes: 2 additions & 13 deletions src/zarr/v3/store/core.py
Expand Up @@ -5,6 +5,7 @@

from zarr.v3.common import BytesLike
from zarr.v3.abc.store import Store
from zarr.v3.store.local import LocalStore


def _dereference_path(root: str, path: str) -> str:
Expand All @@ -24,10 +25,6 @@ def __init__(self, store: Store, path: Optional[str] = None):
self.store = store
self.path = path or ""

@classmethod
def from_path(cls, pth: Path) -> StorePath:
return cls(Store.from_path(pth))

async def get(
self, byte_range: Optional[Tuple[int, Optional[int]]] = None
) -> Optional[BytesLike]:
Expand Down Expand Up @@ -70,14 +67,6 @@ def make_store_path(store_like: StoreLike) -> StorePath:
return store_like
elif isinstance(store_like, Store):
return StorePath(store_like)
# elif isinstance(store_like, Path):
# return StorePath(Store.from_path(store_like))
elif isinstance(store_like, str):
try:
from upath import UPath

return StorePath(Store.from_path(UPath(store_like)))
except ImportError as e:
raise e
# return StorePath(LocalStore(Path(store_like)))
return StorePath(LocalStore(Path(store_like)))
raise TypeError
2 changes: 1 addition & 1 deletion src/zarr/v3/store/local.py
Expand Up @@ -146,7 +146,7 @@ async def list_prefix(self, prefix: str) -> List[str]:
"""

def _list_prefix(root: Path, prefix: str) -> List[str]:
files = [p for p in (root / prefix).rglob("*") if p.is_file()]
files = [str(p) for p in (root / prefix).rglob("*") if p.is_file()]
return files

return await to_thread(_list_prefix, self.root, prefix)
Expand Down
8 changes: 3 additions & 5 deletions src/zarr/v3/sync.py
Expand Up @@ -5,7 +5,6 @@
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
List,
Optional,
Expand Down Expand Up @@ -112,11 +111,10 @@ def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T:
# this should allow us to better type the sync wrapper
return sync(coroutine, loop=self._sync_configuration.asyncio_loop)

def _sync_iter(
self, func: Callable[P, AsyncIterator[T]], *args: P.args, **kwargs: P.kwargs
) -> List[T]:
def _sync_iter(self, coroutine: Coroutine[Any, Any, AsyncIterator[T]]) -> List[T]:
async def iter_to_list() -> List[T]:
# TODO: replace with generators so we don't materialize the entire iterator at once
return [item async for item in func(*args, **kwargs)]
async_iterator = await coroutine
return [item async for item in async_iterator]

return self._sync(iter_to_list())

0 comments on commit eab9fd2

Please sign in to comment.