Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve Mypy errors in v3 branch #1692

Merged
merged 9 commits into from Apr 6, 2024
3 changes: 2 additions & 1 deletion src/zarr/v3/abc/metadata.py
Expand Up @@ -5,11 +5,12 @@
from typing import Dict
from typing_extensions import Self

from dataclasses import fields
from dataclasses import fields, dataclass

from zarr.v3.common import JSON


@dataclass(frozen=True)
class Metadata:
def to_dict(self) -> JSON:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/v3/array.py
Expand Up @@ -182,7 +182,7 @@ def shape(self) -> ChunkCoords:

@property
def size(self) -> int:
return np.prod(self.metadata.shape)
return np.prod(self.metadata.shape).item()

@property
def dtype(self) -> np.dtype:
Expand Down
44 changes: 27 additions & 17 deletions src/zarr/v3/group.py
Expand Up @@ -4,7 +4,7 @@
import asyncio
import json
import logging
from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, Iterator, List
from typing import Any, Dict, Literal, Optional, Union, AsyncIterator, List
from zarr.v3.abc.metadata import Metadata

from zarr.v3.array import AsyncArray, Array
Expand Down Expand Up @@ -46,11 +46,11 @@ def to_bytes(self) -> Dict[str, bytes]:
return {ZARR_JSON: json.dumps(self.to_dict()).encode()}
else:
return {
ZGROUP_JSON: self.zarr_format,
ZGROUP_JSON: str(self.zarr_format).encode(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Flagging that this was actually a bug in hiding. I believe it should have been something like:

Suggested change
ZGROUP_JSON: str(self.zarr_format).encode(),
ZGROUP_JSON: json.dumps({"zarr_format": 2}).encode(),

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in a30af88

ZATTRS_JSON: json.dumps(self.attributes).encode(),
}

def __init__(self, attributes: Dict[str, Any] = None, zarr_format: Literal[2, 3] = 3):
def __init__(self, attributes: Optional[Dict[str, Any]] = None, zarr_format: Literal[2, 3] = 3):
attributes_parsed = parse_attributes(attributes)
zarr_format_parsed = parse_zarr_format(zarr_format)

Expand Down Expand Up @@ -104,7 +104,7 @@ async def open(
zarr_format: Literal[2, 3] = 3,
) -> AsyncGroup:
store_path = make_store_path(store)
zarr_json_bytes = await (store_path / ZARR_JSON).get_async()
zarr_json_bytes = await (store_path / ZARR_JSON).get()
assert zarr_json_bytes is not None

# TODO: consider trying to autodiscover the zarr-format here
Expand Down Expand Up @@ -139,7 +139,7 @@ def from_dict(
store_path: StorePath,
data: Dict[str, Any],
runtime_configuration: RuntimeConfiguration,
) -> Group:
) -> AsyncGroup:
group = cls(
metadata=GroupMetadata.from_dict(data),
store_path=store_path,
Expand Down Expand Up @@ -168,10 +168,12 @@ async def getitem(
zarr_json = json.loads(zarr_json_bytes)
if zarr_json["node_type"] == "group":
return type(self).from_dict(store_path, zarr_json, self.runtime_configuration)
if zarr_json["node_type"] == "array":
elif zarr_json["node_type"] == "array":
return AsyncArray.from_dict(
store_path, zarr_json, runtime_configuration=self.runtime_configuration
)
else:
raise ValueError(f"unexpected node_type: {zarr_json['node_type']}")
elif self.metadata.zarr_format == 2:
# Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs?
# This guarantees that we will always make at least one extra request to the store
Expand Down Expand Up @@ -271,7 +273,7 @@ def __repr__(self):
async def nchildren(self) -> int:
raise NotImplementedError

async def children(self) -> AsyncIterator[AsyncArray, AsyncGroup]:
async def children(self) -> AsyncIterator[Union[AsyncArray, AsyncGroup]]:
raise NotImplementedError

async def contains(self, child: str) -> bool:
Expand Down Expand Up @@ -381,8 +383,12 @@ async def update_attributes_async(self, new_attributes: Dict[str, Any]) -> Group
new_metadata = replace(self.metadata, attributes=new_attributes)

# Write new metadata
await (self.store_path / ZARR_JSON).set_async(new_metadata.to_bytes())
return replace(self, metadata=new_metadata)
to_save = new_metadata.to_bytes()
awaitables = [(self.store_path / key).set(value) for key, value in to_save.items()]
await asyncio.gather(*awaitables)

async_group = replace(self._async_group, metadata=new_metadata)
return replace(self, _async_group=async_group)

@property
def metadata(self) -> GroupMetadata:
Expand All @@ -396,34 +402,38 @@ def attrs(self) -> Attributes:
def info(self):
return self._async_group.info

@property
def store_path(self) -> StorePath:
return self._async_group.store_path

def update_attributes(self, new_attributes: Dict[str, Any]):
self._sync(self._async_group.update_attributes(new_attributes))
return self

@property
def nchildren(self) -> int:
return self._sync(self._async_group.nchildren)
return self._sync(self._async_group.nchildren())

@property
def children(self) -> List[Array, Group]:
_children = self._sync_iter(self._async_group.children)
def children(self) -> List[Union[Array, Group]]:
_children = self._sync_iter(self._async_group.children())
return [Array(obj) if isinstance(obj, AsyncArray) else Group(obj) for obj in _children]

def __contains__(self, child) -> bool:
return self._sync(self._async_group.contains(child))

def group_keys(self) -> Iterator[str]:
return self._sync_iter(self._async_group.group_keys)
def group_keys(self) -> List[str]:
return self._sync_iter(self._async_group.group_keys())

def groups(self) -> List[Group]:
# TODO: in v2 this was a generator that return key: Group
return [Group(obj) for obj in self._sync_iter(self._async_group.groups)]
return [Group(obj) for obj in self._sync_iter(self._async_group.groups())]

def array_keys(self) -> List[str]:
return self._sync_iter(self._async_group.array_keys)
return self._sync_iter(self._async_group.array_keys())

def arrays(self) -> List[Array]:
return [Array(obj) for obj in self._sync_iter(self._async_group.arrays)]
return [Array(obj) for obj in self._sync_iter(self._async_group.arrays())]

def tree(self, expand=False, level=None) -> Any:
return self._sync(self._async_group.tree(expand=expand, level=level))
Expand Down
6 changes: 3 additions & 3 deletions src/zarr/v3/metadata.py
@@ -1,6 +1,6 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, cast, Dict, Iterable
from typing import TYPE_CHECKING, cast, Dict, Iterable, Any
from dataclasses import dataclass, field
import json
import numpy as np
Expand All @@ -10,7 +10,7 @@


if TYPE_CHECKING:
from typing import Any, Literal, Union, List, Optional, Tuple
from typing import Literal, Union, List, Optional, Tuple
from zarr.v3.codecs.pipeline import CodecPipeline


Expand Down Expand Up @@ -244,7 +244,7 @@ class ArrayV2Metadata(Metadata):
filters: Optional[List[Dict[str, Any]]] = None
dimension_separator: Literal[".", "/"] = "."
compressor: Optional[Dict[str, Any]] = None
attributes: Optional[Dict[str, Any]] = field(default_factory=dict)
attributes: Optional[Dict[str, Any]] = cast(Dict[str, Any], field(default_factory=dict))
zarr_format: Literal[2] = field(init=False, default=2)

def __init__(
Expand Down
6 changes: 4 additions & 2 deletions src/zarr/v3/store/core.py
Expand Up @@ -26,7 +26,8 @@ def __init__(self, store: Store, path: Optional[str] = None):

@classmethod
def from_path(cls, pth: Path) -> StorePath:
return cls(Store.from_path(pth))
# NOT SOLVED: This is instantiating an ABC + there is no from_path method
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not solved here, as per comment. What subclass of Store should this use?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I don't think we're using StorePath.from_path() anymore. If that's the case, I'd be comfortable removing it here.

return cls(Store.from_path(pth)) # type: ignore

async def get(
self, byte_range: Optional[Tuple[int, Optional[int]]] = None
Expand Down Expand Up @@ -76,7 +77,8 @@ def make_store_path(store_like: StoreLike) -> StorePath:
try:
from upath import UPath

return StorePath(Store.from_path(UPath(store_like)))
# NOT SOLVED: Similar here, ABC instantiation + no from_path method
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above.

return StorePath(Store.from_path(UPath(store_like))) # type: ignore
except ImportError as e:
raise e
# return StorePath(LocalStore(Path(store_like)))
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/v3/store/local.py
Expand Up @@ -146,7 +146,7 @@ async def list_prefix(self, prefix: str) -> List[str]:
"""

def _list_prefix(root: Path, prefix: str) -> List[str]:
files = [p for p in (root / prefix).rglob("*") if p.is_file()]
files = [str(p) for p in (root / prefix).rglob("*") if p.is_file()]
return files

return await to_thread(_list_prefix, self.root, prefix)
Expand Down
8 changes: 3 additions & 5 deletions src/zarr/v3/sync.py
Expand Up @@ -5,7 +5,6 @@
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
List,
Optional,
Expand Down Expand Up @@ -112,11 +111,10 @@ def _sync(self, coroutine: Coroutine[Any, Any, T]) -> T:
# this should allow us to better type the sync wrapper
return sync(coroutine, loop=self._sync_configuration.asyncio_loop)

def _sync_iter(
self, func: Callable[P, AsyncIterator[T]], *args: P.args, **kwargs: P.kwargs
) -> List[T]:
def _sync_iter(self, coroutine: Coroutine[Any, Any, AsyncIterator[T]]) -> List[T]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the args/kwargs here so that it accepts a coroutine, which is how it has been used so far.

async def iter_to_list() -> List[T]:
# TODO: replace with generators so we don't materialize the entire iterator at once
return [item async for item in func(*args, **kwargs)]
async_iterator = await coroutine
return [item async for item in async_iterator]

return self._sync(iter_to_list())