Skip to content

Commit

Permalink
ENH: Use py3.10 type hinting style for type annotations using the pyupgrade package [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
cheginit committed Nov 5, 2022
1 parent aa512a8 commit 8c02a9d
Showing 1 changed file with 37 additions and 38 deletions.
75 changes: 37 additions & 38 deletions pygeohydro/waterdata.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Accessing data from the supported databases through their APIs."""
from __future__ import annotations

import contextlib
import io
import logging
import re
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Iterable, Sequence, Union

import async_retriever as ar
import cytoolz as tlz
Expand All @@ -13,18 +14,18 @@
import pandas as pd
import pygeoogc as ogc
import pygeoutils as geoutils
import pyproj
import xarray as xr
from loguru import logger
from pygeoogc import ServiceURL
from pygeoogc import ZeroMatchedError as ZeroMatchedErrorOGC
from pygeoogc import utils as ogc_utils
from pynhd import WaterData

from .exceptions import DataNotAvailableError, InputTypeError, InputValueError, ZeroMatchedError
from .helpers import logger

DEF_CRS = "epsg:4326"
T_FMT = "%Y-%m-%d"
EXPIRE = -1
CRSTYPE = Union[int, str, pyproj.CRS]


class NWIS:
Expand All @@ -34,7 +35,7 @@ def __init__(self) -> None:
self.url = ServiceURL().restful.nwis

@staticmethod
def retrieve_rdb(url: str, payloads: List[Dict[str, str]]) -> pd.DataFrame:
def retrieve_rdb(url: str, payloads: list[dict[str, str]]) -> pd.DataFrame:
"""Retrieve and process requests with RDB format.
Parameters
Expand Down Expand Up @@ -77,8 +78,8 @@ def retrieve_rdb(url: str, payloads: List[Dict[str, str]]) -> pd.DataFrame:

@staticmethod
def _validate_usgs_queries(
queries: List[Dict[str, str]], expanded: bool = False
) -> List[Dict[str, str]]:
queries: list[dict[str, str]], expanded: bool = False
) -> list[dict[str, str]]:
"""Validate queries to be used with USGS Site Web Service.
Parameters
Expand Down Expand Up @@ -182,7 +183,7 @@ def _validate_usgs_queries(

def get_info(
self,
queries: Union[Dict[str, str], List[Dict[str, str]]],
queries: dict[str, str] | list[dict[str, str]],
expanded: bool = False,
fix_names: bool = True,
) -> gpd.GeoDataFrame:
Expand Down Expand Up @@ -241,13 +242,11 @@ def fix_station_nm(station_nm: str) -> str:
sites["begin_date"] = pd.to_datetime(sites["begin_date"])
sites["end_date"] = pd.to_datetime(sites["end_date"])

gii = WaterData("gagesii", DEF_CRS, False)
logging.getLogger("pynhd.core").setLevel(logging.ERROR)
gii = WaterData("gagesii", 4326, False)
try:
gages = gii.byid("staid", sites.site_no.to_list())
except ZeroMatchedErrorOGC:
gages = gpd.GeoDataFrame()
logging.getLogger("pynhd.core").setLevel(logging.INFO)

if len(gages) > 0:
sites = pd.merge(
Expand Down Expand Up @@ -310,7 +309,7 @@ def get_parameter_codes(self, keyword: str) -> pd.DataFrame:
return self.retrieve_rdb(url, kwds)

@staticmethod
def _to_xarray(qobs: pd.DataFrame, long_names: Dict[str, str], mmd: bool) -> xr.Dataset:
def _to_xarray(qobs: pd.DataFrame, long_names: dict[str, str], mmd: bool) -> xr.Dataset:
"""Convert a pandas.DataFrame to an xarray.Dataset."""
ds = xr.Dataset(
data_vars={
Expand All @@ -335,7 +334,7 @@ def _to_xarray(qobs: pd.DataFrame, long_names: Dict[str, str], mmd: bool) -> xr.
return ds

@staticmethod
def _get_attrs(siteinfo: pd.DataFrame, mmd: bool) -> Tuple[Dict[str, Any], Dict[str, str]]:
def _get_attrs(siteinfo: pd.DataFrame, mmd: bool) -> tuple[dict[str, Any], dict[str, str]]:
"""Get attributes of the stations that have streamflow data."""
cols = {
"site_no": "site identification number",
Expand Down Expand Up @@ -366,10 +365,10 @@ def _get_attrs(siteinfo: pd.DataFrame, mmd: bool) -> Tuple[Dict[str, Any], Dict[

@staticmethod
def _check_inputs(
station_ids: Union[Sequence[str], str],
dates: Tuple[str, str],
utc: Optional[bool],
) -> Tuple[List[str], pd.Timestamp, pd.Timestamp]:
station_ids: Sequence[str] | str,
dates: tuple[str, str],
utc: bool | None,
) -> tuple[list[str], pd.Timestamp, pd.Timestamp]:
"""Validate inputs."""
if not isinstance(station_ids, (str, Sequence, Iterable)):
raise InputTypeError("ids", "str or list of str")
Expand Down Expand Up @@ -403,7 +402,7 @@ def _drainage_area_sqm(self, siteinfo: pd.DataFrame, freq: str) -> pd.Series:
]
info = self.get_info(queries, expanded=True)

def get_idx(ids: List[str]) -> Tuple[pd.Index, pd.Index]:
def get_idx(ids: list[str]) -> tuple[pd.Index, pd.Index]:
return info.site_no.isin(ids), area.site_no.isin(ids)

i_idx, a_idx = get_idx(sids)
Expand All @@ -424,7 +423,7 @@ def _get_streamflow(
start_dt: str,
end_dt: str,
freq: str,
kwargs: Dict[str, str],
kwargs: dict[str, str],
) -> pd.DataFrame:
"""Convert json to dataframe."""
payloads = [
Expand All @@ -440,7 +439,7 @@ def _get_streamflow(
[f"{self.url}/{freq}"] * len(payloads), [{"params": p} for p in payloads]
)

def get_site_id(site_cd: Dict[str, str]) -> str:
def get_site_id(site_cd: dict[str, str]) -> str:
"""Get site id."""
return f"{site_cd['agencyCode']}-{site_cd['value']}"

Expand All @@ -452,7 +451,7 @@ def get_site_id(site_cd: Dict[str, str]) -> str:
if len(r_ts) == 0:
raise DataNotAvailableError("discharge")

def to_df(col: str, values: Dict[str, Any]) -> pd.DataFrame:
def to_df(col: str, values: dict[str, Any]) -> pd.DataFrame:
try:
discharge = pd.DataFrame.from_records(
values, exclude=["qualifiers"], index=["dateTime"]
Expand Down Expand Up @@ -488,12 +487,12 @@ def to_df(col: str, values: Dict[str, Any]) -> pd.DataFrame:

def get_streamflow(
self,
station_ids: Union[Sequence[str], str],
dates: Tuple[str, str],
station_ids: Sequence[str] | str,
dates: tuple[str, str],
freq: str = "dv",
mmd: bool = False,
to_xarray: bool = False,
) -> Union[pd.DataFrame, xr.Dataset]:
) -> pd.DataFrame | xr.Dataset:
"""Get mean daily streamflow observations from USGS.
Parameters
Expand Down Expand Up @@ -579,7 +578,7 @@ def get_streamflow(
ms2mmd = 1000.0 * 24.0 * 3600.0
try:
qobs = pd.DataFrame(
{c: q / area_sqm.loc[c.split("-")[-1]] * ms2mmd for c, q in qobs.iteritems()}
{c: q / area_sqm.loc[c.split("-")[-1]] * ms2mmd for c, q in qobs.items()}
)
except KeyError as ex:
raise DataNotAvailableError("drainage") from ex
Expand All @@ -591,17 +590,17 @@ def get_streamflow(


def interactive_map(
bbox: Tuple[float, float, float, float],
crs: str = DEF_CRS,
nwis_kwds: Optional[Dict[str, Any]] = None,
bbox: tuple[float, float, float, float],
crs: CRSTYPE = 4326,
nwis_kwds: dict[str, Any] | None = None,
) -> folium.Map:
"""Generate an interactive map including all USGS stations within a bounding box.
Parameters
----------
bbox : tuple
List of corners in this order (west, south, east, north)
crs : str, optional
crs : str, int, or pyproj.CRS, optional
CRS of the input bounding box, defaults to EPSG:4326.
nwis_kwds : dict, optional
Optional keywords to include in the NWIS request as a dictionary like so:
Expand All @@ -622,7 +621,7 @@ def interactive_map(
>>> n_stations
10
"""
bbox = ogc.utils.match_crs(bbox, crs, DEF_CRS)
bbox = ogc.utils.match_crs(bbox, crs, 4326)
ogc.utils.check_bbox(bbox)

nwis = NWIS()
Expand Down Expand Up @@ -729,7 +728,7 @@ def get_param_table(self) -> pd.Series:
params = params[0].iloc[:29].drop(columns="Discussion")
return params.groupby("REST parameter")["Argument"].apply(",".join)

def lookup_domain_values(self, endpoint: str) -> List[str]:
def lookup_domain_values(self, endpoint: str) -> list[str]:
"""Get the domain values for the target endpoint."""
valid_endpoints = [
"statecode",
Expand Down Expand Up @@ -762,7 +761,7 @@ def _base_url(self, endpoint: str) -> str:
return f"{self.wq_url}/data/{endpoint}/search"

def get_json(
self, endpoint: str, kwds: Dict[str, str], request_method: str = "GET"
self, endpoint: str, kwds: dict[str, str], request_method: str = "GET"
) -> gpd.GeoDataFrame:
"""Get the JSON response from the Water Quality Web Service.
Expand All @@ -785,14 +784,14 @@ def get_json(
ar.retrieve_json([self._base_url(endpoint)], req_kwds, request_method=request_method)
)

def _check_kwds(self, wq_kwds: Dict[str, str]) -> None:
def _check_kwds(self, wq_kwds: dict[str, str]) -> None:
"""Check the validity of the Water Quality Web Service keyword arguments."""
invalids = [k for k in wq_kwds if k not in self.keywords.index]
if len(invalids) > 0:
raise InputValueError("wq_kwds", invalids)

def station_bybbox(
self, bbox: Tuple[float, float, float, float], wq_kwds: Optional[Dict[str, str]]
self, bbox: tuple[float, float, float, float], wq_kwds: dict[str, str] | None
) -> gpd.GeoDataFrame:
"""Retrieve station info within bounding box.
Expand Down Expand Up @@ -821,7 +820,7 @@ def station_bybbox(
return self.get_json("station", kwds)

def station_bydistance(
self, lon: float, lat: float, radius: float, wq_kwds: Optional[Dict[str, str]]
self, lon: float, lat: float, radius: float, wq_kwds: dict[str, str] | None
) -> gpd.GeoDataFrame:
"""Retrieve station within a radius (decimal miles) of a point.
Expand Down Expand Up @@ -856,7 +855,7 @@ def station_bydistance(
return self.get_json("station", kwds)

def get_csv(
self, endpoint: str, kwds: Dict[str, str], request_method: str = "GET"
self, endpoint: str, kwds: dict[str, str], request_method: str = "GET"
) -> pd.DataFrame:
"""Get the CSV response from the Water Quality Web Service.
Expand All @@ -879,7 +878,7 @@ def get_csv(
return pd.read_csv(io.BytesIO(r[0]), compression="zip")

def data_bystation(
self, station_ids: Union[str, List[str]], wq_kwds: Optional[Dict[str, str]]
self, station_ids: str | list[str], wq_kwds: dict[str, str] | None
) -> pd.DataFrame:
"""Retrieve data for a single station.
Expand Down

0 comments on commit 8c02a9d

Please sign in to comment.