Improved performance for `remove_bottom_values` #119

Open
jbusecke opened this issue Oct 19, 2021 · 0 comments

Comments

@jbusecke
Owner

I recently recoded the `remove_bottom_values` function using numba:

from numba import float64, guvectorize
import numpy as np
import xarray as xr

@guvectorize(
    [
        (float64[:], float64[:]),
    ],
    "(n)->(n)",
    nopython=True,
)
def _remove_last_value(data, output):
    # initialize output
    output[:] = data[:]
    for i in range(len(data)-1):
        if np.isnan(output[i+1]):
            output[i] = np.nan
    # take care of boundaries
    if not np.isnan(output[-1]):
        output[-1] = np.nan

def remove_bottom_values_numba(da, dim='lev'):
    """Apply `_remove_last_value` along `dim` (`apply_ufunc` moves `dim` to the last axis)."""
    out = xr.apply_ufunc(
        _remove_last_value,
        da,
        input_core_dims=[[dim]],
        output_core_dims=[[dim]],
        dask="parallelized",
        output_dtypes=[da.dtype],
    )
    return out

def remove_bottom_values_recoded(ds, dim="lev", fill_val=-1e10):
    """Remove the deepest values that are not nan along the dimension `dim`"""
    # for now assume that values of `dim` increase along the dimension
    if ds[dim][0] > ds[dim][-1]:
        raise ValueError(
            f"It seems like `{dim}` has decreasing values. This is not supported yet. Please sort before."
        )
    else:
        # pass `dim` through, otherwise the default 'lev' is always used
        ds_masked = xr.Dataset({va: remove_bottom_values_numba(ds[va], dim=dim) for va in ds.data_vars})
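        # restore the original dimension order; test membership against `.dims`,
        # since `in` on a DataArray checks its values rather than its dimension names
        ds_masked = ds_masked.transpose(*[di for di in ds.dims if di in ds_masked.dims])
        ds_masked = ds_masked.assign_coords({co: ds[co].transpose(*[di for di in ds.dims if di in ds[co].dims]) for co in ds.coords})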
        ds_masked.attrs = ds.attrs
        return ds_masked
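
To check the kernel's behavior on a single column without compiling anything, a pure-NumPy stand-in can help. This is only a sketch: `remove_last_value_reference` is a hypothetical name, not part of the package, and it is meant to mirror the logic of `_remove_last_value` above (drop every value whose next neighbor is NaN, then always drop the last element).

```python
import numpy as np


def remove_last_value_reference(data):
    """NumPy-only stand-in (hypothetical) for the `_remove_last_value` kernel.

    Operates along the last axis, matching the "(n)->(n)" signature.
    """
    data = np.asarray(data, dtype=np.float64)
    output = data.copy()
    # a value is dropped when the next value along the axis is NaN
    next_is_nan = np.isnan(data[..., 1:])
    output[..., :-1][next_is_nan] = np.nan
    # boundary: the last element is always removed
    output[..., -1] = np.nan
    return output


# a column whose two deepest cells are already NaN
column = np.array([10.0, 11.0, 12.0, np.nan, np.nan])
result = remove_last_value_reference(column)  # 12.0 becomes NaN as well
```

Note that, like the kernel, this removes the value before *every* NaN, so a column with interior gaps loses one value above each gap.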

I am planning on implementing this here at some point. It might also be nice to generalize this to optionally keep only the bottom values, and to handle an arbitrary number of values rather than just one.
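
One possible shape for that generalization, written in plain NumPy rather than numba so the logic is easy to inspect. Everything here is an assumption for illustration: the name `mask_bottom_values` and the parameters `n` and `keep` are hypothetical. It counts the non-NaN values from the bottom of each column, so with interior gaps it behaves differently from the kernel above (it touches only the deepest `n` valid values).

```python
import numpy as np


def mask_bottom_values(data, n=1, keep=False):
    """Hypothetical sketch: remove (or keep only) the deepest `n` non-NaN values.

    Works along the last axis, with depth assumed to increase with index.
    """
    data = np.asarray(data, dtype=np.float64)
    valid = ~np.isnan(data)
    # for each cell, the number of valid cells at or below it
    rank_from_bottom = np.cumsum(valid[..., ::-1], axis=-1)[..., ::-1]
    bottom = valid & (rank_from_bottom <= n)
    if keep:
        # keep only the bottom values, mask everything else
        return np.where(bottom, data, np.nan)
    # mask the bottom values, keep everything else
    return np.where(bottom, np.nan, data)


columns = np.array(
    [
        [1.0, 2.0, 3.0, np.nan],
        [4.0, 5.0, 6.0, 7.0],
    ]
)
removed = mask_bottom_values(columns, n=1)           # drop the deepest value
kept = mask_bottom_values(columns, n=2, keep=True)   # keep only the deepest two
```

If this route were taken, the same function could presumably be passed to `xr.apply_ufunc` with the core dimension settings used above, or ported to a `guvectorize` kernel for speed.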
