Skip to content

Commit 38281d7

Browse files
committed
another batch
1 parent 06caa5e commit 38281d7

File tree

10 files changed

+57
-35
lines changed

10 files changed

+57
-35
lines changed

.github/workflows/type-check.yml

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ on:
88
paths:
99
- yt_experiments/**/*.py
1010
- pyproject.toml
11+
- requirements/typecheck.txt
1112
- .github/workflows/type-checking.yaml
1213
workflow_dispatch:
1314

@@ -17,11 +18,6 @@ jobs:
1718
name: type check
1819
timeout-minutes: 60
1920

20-
concurrency:
21-
# auto-cancel any in-progress job *on the same branch*
22-
group: ${{ github.workflow }}-${{ github.ref }}
23-
cancel-in-progress: true
24-
2521
steps:
2622
- name: Checkout repo
2723
uses: actions/checkout@v4
@@ -37,7 +33,7 @@ jobs:
3733
- name: Build
3834
run: |
3935
python3 -m pip install --upgrade pip
40-
python3 -m pip install "mypy==1.11.2"
36+
python3 -m pip install -r requirements/typecheck.txt
4137
4238
- name: list installed deps
4339
run: python -m pip list

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,4 @@ warn_unreachable = true
5454
disallow_untyped_defs = false
5555
disallow_incomplete_defs = false
5656
disable_error_code = ["import-untyped", "import-not-found"]
57+
no_implicit_reexport = false

requirements/typecheck.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
mypy==1.11.2

yt_xarray/accessor/_xr_to_yt.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,14 @@ def __init__(
4545
self.sel_dict_type = sel_dict_type
4646

4747
# all these attributes are set in _process_selection
48-
self.selected_shape: Tuple[int] | None = None
49-
self.full_bbox = None
48+
self.selected_shape: Tuple[int, ...]
49+
self.full_bbox: npt.NDArray
5050
self.selected_bbox: npt.NDArray
51-
self.full_coords: Tuple[str]
52-
self.selected_coords: Tuple[str]
51+
self.full_coords: Tuple[str, ...]
52+
self.selected_coords: Tuple[str, ...]
5353
self.starting_indices: npt.NDArray
54-
self.selected_time: float | None = None
55-
self.ndims: int | None = None
54+
self.selected_time: float
55+
self.ndims: int
5656
self.grid_type: _GridType
5757
self.cell_widths: list[Any]
5858
self.global_dims: list[Any]
@@ -274,7 +274,9 @@ def select_from_xr(self, xr_ds: xr.Dataset, field: str) -> xr.DataArray:
274274

275275
return vars
276276

277-
def interp_validation(self, geometry):
277+
def interp_validation(
278+
self, geometry: str
279+
) -> tuple[bool, tuple[int, ...], npt.NDArray]:
278280
# checks if yt will need to interpolate to cell center
279281
# returns a tuple of (bool, shape, bbox). If the bool is True then
280282
# interpolation is required.
@@ -345,7 +347,7 @@ def interp_validation(self, geometry):
345347
known_coord_aliases = _default_known_coord_aliases.copy()
346348

347349

348-
def reset_coordinate_aliases():
350+
def reset_coordinate_aliases() -> None:
349351
kys_to_pop = [
350352
ky
351353
for ky in known_coord_aliases.keys()
@@ -399,7 +401,9 @@ def _cf_xr_coord_disamb(
399401
return None, True
400402

401403

402-
def _convert_to_yt_internal_coords(coord_list: tuple[str], xr_field: xr.DataArray):
404+
def _convert_to_yt_internal_coords(
405+
coord_list: tuple[str] | list[str], xr_field: xr.DataArray
406+
):
403407
yt_coords = []
404408
for c in coord_list:
405409
cname = c.lower()

yt_xarray/sample_data.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from typing import Mapping
2+
13
import numpy as np
24
import xarray as xr
35

46

57
def load_random_xr_data(
6-
fields: dict[str, tuple[str, ...]],
7-
dims: dict[str, tuple[int | float, int | float, int]],
8+
fields: Mapping[str, tuple[str, ...]],
9+
dims: Mapping[str, tuple[int | float, int | float, int]],
810
length_unit: str | None = None,
911
) -> xr.Dataset:
1012
"""

yt_xarray/tests/test_chunking.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def test_chunk_bad_length():
142142
"field0": ("x", "y", "z"),
143143
"field1": ("x", "y", "z"),
144144
}
145-
dims = {"x": (0, 1, 30), "y": (0, 2, 40), "z": (-1, 0.5, 20)}
145+
dims = {"x": (0, 1, 30), "y": (0, 2, 40), "z": (-1.0, 0.5, 20)}
146146
ds = sample_data.load_random_xr_data(fields, dims)
147147

148148
with pytest.raises(ValueError, match="The number of elements in "):
@@ -175,7 +175,7 @@ def test_chunk_info_caching():
175175
chunksizes = np.array([5, 5, 5], dtype="int")
176176
data_shape = (10, 15, 20)
177177

178-
def _get_ch():
178+
def _get_ch() -> ChunkInfo:
179179
return ChunkInfo(data_shape, chunksizes)
180180

181181
ch = _get_ch()

yt_xarray/tests/test_xr_to_yt.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import builtins
2+
from typing import Any
23

34
import numpy as np
45
import pytest
56
import xarray as xr
67
import yt
8+
from numpy import typing as npt
9+
from yt.data_objects.static_output import Dataset as ytDataset
710

811
import yt_xarray.accessor._xr_to_yt as xr2yt
912
from yt_xarray.utilities._utilities import (
@@ -39,7 +42,7 @@ def ds_xr():
3942
def test_selection_aliases(coord):
4043
for othername in xr2yt._coord_aliases[coord]:
4144
kwargs = {c_m_ds_kwargs[coord]: othername}
42-
ds = construct_minimal_ds(**kwargs)
45+
ds = construct_minimal_ds(**kwargs) # type: ignore[arg-type]
4346
fields = list(ds.data_vars)
4447
sel = xr2yt.Selection(ds, fields)
4548
assert np.all(sel.starting_indices == np.array((0, 0, 0)))
@@ -54,7 +57,13 @@ def test_selection_aliases(coord):
5457
assert xr2yt.known_coord_aliases[othername] in sel.yt_coord_names
5558

5659

57-
def _isel_tester(ds_xr, sel, fields, coord, start_index):
60+
def _isel_tester(
61+
ds_xr: xr.Dataset,
62+
sel: xr2yt.Selection,
63+
fields: list[str],
64+
coord: str,
65+
start_index: int,
66+
) -> None:
5867
dim_id = ds_xr.data_vars[fields[0]].dims.index(coord)
5968
expected = np.array((0, 0, 0))
6069
expected[dim_id] = start_index
@@ -73,9 +82,11 @@ def _isel_tester(ds_xr, sel, fields, coord, start_index):
7382

7483

7584
@pytest.mark.parametrize("coord", ("latitude", "longitude", "depth"))
76-
def test_selection_isel(ds_xr, coord):
85+
def test_selection_isel(ds_xr, coord_input):
86+
coord: str = coord_input
7787
fields = list(ds_xr.data_vars)
7888

89+
sel_dict: dict[str, Any]
7990
sel_dict = {coord: slice(1, len(ds_xr.coords[coord]))}
8091
sel_dict_type = "isel"
8192
sel = xr2yt.Selection(ds_xr, fields, sel_dict=sel_dict, sel_dict_type=sel_dict_type)
@@ -476,7 +487,9 @@ def test_add_3rd_axis_name(yt_geom):
476487
_ = xr2yt._add_3rd_axis_name("bad_geometry", expected[:-1])
477488

478489

479-
def _get_pixelized_slice(yt_ds):
490+
def _get_pixelized_slice(
491+
yt_ds: ytDataset,
492+
) -> tuple[Any, tuple[npt.NDArray, npt.NDArray | None] | npt.NDArray]:
480493
slc = yt_ds.slice(
481494
yt_ds.coordinates.axis_id["depth"],
482495
yt_ds.domain_center[yt_ds.coordinates.axis_id["depth"]],
@@ -492,7 +505,9 @@ def _get_pixelized_slice(yt_ds):
492505
return slc, vals
493506

494507

495-
def _get_ds_for_reverse_tests(stretched, use_callable, chunksizes):
508+
def _get_ds_for_reverse_tests(
509+
stretched: bool, use_callable: bool, chunksizes: tuple[int, ...]
510+
) -> ytDataset:
496511
ds = construct_minimal_ds(
497512
min_x=1,
498513
max_x=359,

yt_xarray/transformations.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import abc
2-
from typing import Callable, List, Optional, Tuple, Union
2+
from typing import Callable, List, Mapping, Optional, Tuple, Union
33

44
import numpy as np
55
import unyt
66
import xarray as xr
77
import yt
8+
from numpy import typing as npt
89
from unyt import earth_radius as _earth_radius
910

1011
from yt_xarray.accessor import _xr_to_yt
@@ -119,7 +120,7 @@ def _validate_input_coords(self, coords, input_coord_type: str):
119120

120121
return new_coords
121122

122-
def to_native(self, **coords):
123+
def to_native(self, **coords: npt.NDArray) -> list[npt.NDArray]:
123124
"""
124125
Calculate the native coordinates from transformed coordinates.
125126
@@ -131,7 +132,7 @@ def to_native(self, **coords):
131132
132133
Returns
133134
-------
134-
tuple
135+
list
135136
coordinate values in the native coordinate system, in order
136137
of the native_coords attribute.
137138
@@ -144,7 +145,7 @@ def to_native(self, **coords):
144145
new_coords = self._validate_input_coords(coords, "transformed")
145146
return self._calculate_native(**new_coords)
146147

147-
def to_transformed(self, **coords):
148+
def to_transformed(self, **coords: npt.NDArray) -> list[npt.NDArray]:
148149
"""
149150
Calculate the transformed coordinates from native coordinates.
150151
@@ -156,7 +157,7 @@ def to_transformed(self, **coords):
156157
157158
Returns
158159
-------
159-
tuple
160+
list
160161
coordinate values in the transformed coordinate system, in order
161162
of the transformed_coords attribute.
162163
@@ -223,7 +224,7 @@ class LinearScale(Transformer):
223224
224225
"""
225226

226-
def __init__(self, native_coords: Tuple[str], scale: Optional[dict] = None):
227+
def __init__(self, native_coords: Tuple[str, ...], scale: Optional[dict] = None):
227228
if scale is None:
228229
scale = {}
229230

@@ -234,20 +235,22 @@ def __init__(self, native_coords: Tuple[str], scale: Optional[dict] = None):
234235
transformed_coords = tuple([nc + "_sc" for nc in native_coords])
235236
super().__init__(native_coords, transformed_coords)
236237

237-
def _calculate_transformed(self, **coords):
238+
def _calculate_transformed(self, **coords) -> list[npt.NDarray]:
238239
transformed = []
239240
for nc_sc in self.transformed_coords:
240241
nc = nc_sc[:-3] # native coord name. e.g., go from "x_sc" to just "x"
241242
transformed.append(np.asarray(coords[nc]) * self.scale[nc])
242243
return transformed
243244

244-
def _calculate_native(self, **coords):
245+
def _calculate_native(self, **coords) -> list[npt.NDarray]:
245246
native = []
246247
for nc in self.native_coords:
247248
native.append(np.asarray(coords[nc + "_sc"]) / self.scale[nc])
248249
return native
249250

250-
def calculate_transformed_bbox(self, bbox_dict: dict) -> np.ndarray:
251+
def calculate_transformed_bbox(
252+
self, bbox_dict: Mapping[str, npt.NDArray]
253+
) -> npt.NDArray:
251254
"""
252255
Calculates a bounding box in transformed coordinates for a bounding box dictionary
253256
in native coordinates.

yt_xarray/utilities/_grid_decomposition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ class ChunkInfo:
442442

443443
def __init__(
444444
self,
445-
data_shp: Tuple[int,],
445+
data_shp: Tuple[int, ...],
446446
chunksizes: np.ndarray,
447447
starting_index_offset: np.ndarray = None,
448448
):

yt_xarray/utilities/_utilities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def construct_ds_with_extra_dim(
242242
return xr.Dataset(data_vars=data_vars)
243243

244244

245-
def _find_file(file: PathLike[str]) -> PathLike[str] | str:
245+
def _find_file(file: PathLike[str] | str) -> PathLike[str] | str:
246246
if os.path.isfile(file):
247247
return file
248248

0 commit comments

Comments
 (0)