Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "where" based ufunc masked array support decorator #98

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
63 changes: 62 additions & 1 deletion gsw/_utilities.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from functools import wraps
from functools import wraps, reduce
from itertools import chain

import numpy as np

Expand All @@ -15,6 +16,66 @@ def masked_to_nan(arg):
else:
return np.asarray(arg, dtype=float)

def masked_array_support(f):
"""Decorator which adds support for np.ma.masked_arrays to the _wrapped_ufuncs

When one or more masked arrays are encountered as arguments or keyword
arguments, the boolean masks are all logical ORed together then logical
NOT is applied to get the ufunc.where parameter.

If no masked arrays are found, the default where argument of True is always
passed into the wrapped function as a kwarg.

If a where keyword argument is present, it will be used instead of the
masked derived value.

All args/kwargs are then passed directly to the wrapped function
"""

@wraps(f)
def wrapper(*args, **kwargs):
where = True # this is the default value for the where kwarg for all ufuncs

# the only thing done when a masked array is encountered is to figure out
# the correct thing to set the where argument to
# the order of the args and kwargs is unimportant.
# this logic inspired by how the np.ma wrapped ufuncs work
# https://github.com/numpy/numpy/blob/cafec60a5e28af98fb8798049edd7942720d2d74/numpy/ma/core.py#L1016-L1025
has_masked_args = any(
np.ma.isMaskedArray(arg) for arg in chain(args, kwargs.values())
)
if has_masked_args:
# we want getmask rather than getmaskarray for performance reasons
mask = reduce(
np.logical_or,
(np.ma.getmask(arg) for arg in chain(args, kwargs.values())),
)
where = ~mask

new_kwargs = {"where": where}
new_kwargs.update(
**kwargs
) # allow user override of the where kwarg if they passed it in

ret = f(*args, **new_kwargs)

if has_masked_args:
# I suspect based on __array_priority__ the returned values might
# not be masked arrays with mixed with other array subclasses with
# a higher prioirty
#
# masked_invalid will retain the existing mask and mask
# any new invalid values (if e.g. the result of unmasked inputs
# was nan/inf)
if isinstance(ret, tuple):
return tuple(np.ma.masked_invalid(rv) for rv in ret)
return np.ma.masked_invalid(ret)

return ret

return wrapper


def match_args_return(f):
"""
Decorator for most functions that operate on profile data.
Expand Down