Skip to content

Commit

Permalink
Draft fitter for spline models
Browse files Browse the repository at this point in the history
  • Loading branch information
alexji committed Sep 9, 2018
1 parent 5640b7b commit 1c88ccf
Show file tree
Hide file tree
Showing 2 changed files with 318 additions and 0 deletions.
258 changes: 258 additions & 0 deletions specutils/fitting/spline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
from __future__ import print_function, division, absolute_import

import numpy as np
from scipy import interpolate

from astropy.modeling.core import FittableModel, Model
from astropy.modeling.functional_models import Shift
from astropy.modeling.parameters import Parameter
from astropy.modeling.utils import poly_map_domain, comb
from astropy.modeling.fitting import _FitterMeta, fitter_unit_support
from astropy.utils import indent, check_broadcast
from astropy.units import Quantity


__all__ = []

class SplineModel(FittableModel):
"""
Wrapper around scipy.interpolate.splrep and splev
Analogous to scipy.interpolate.UnivariateSpline() if knots unspecified,
and scipy.interpolate.LSQUnivariateSpline if knots are specified
There are two ways to make a spline model.
(1) you have the spline auto-determine knots from the data
(2) you specify the knots
"""

linear = False # I think? I have no idea?
col_fit_deriv = False # Not sure what this is

def __init__(self, degree=3, smoothing=None, knots=None, extrapolate_mode=0):
"""
Set up a spline model.
degree: degree of the spline (default 3)
In scipy fitpack, this is "k"
smoothing (optional): smoothing value for automatically determining knots
In scipy fitpack, this is "s"
By default, uses a
knots (optional): spline knots (boundaries of piecewise polynomial)
If not specified, will automatically determine knots based on
degree + smoothing
extrapolate_mode (optional): how to deal with solution outside of interval.
(see scipy.interpolate.splev)
if 0 (default): return the extrapolated value
if 1, return 0
if 2, raise a ValueError
if 3, return the boundary value
"""
self._degree = degree
self._smoothing = smoothing
self._knots = self.verify_knots(knots)
self.extrapolate_mode = extrapolate_mode

## This is used to evaluate the spline
## When None, raises an error when trying to evaluate the spline
self._tck = None

self._param_names = ()

def verify_knots(self, knots):
"""
Basic knot array vetting.
The goal of having this is to enable more useful error messages
than scipy (if needed).
"""
if knots is None: return None
knots = np.array(knots)
assert len(knots.shape) == 1, knots.shape
knots = np.sort(knots)
assert len(np.unique(knots)) == len(knots), knots
return knots

############
## Getters
############
def get_degree(self):
""" Spline degree (k in FITPACK) """
return self._degree
def get_smoothing(self):
""" Spline smoothing (s in FITPACK) """
return self._smoothing
def get_knots(self):
""" Spline knots (t in FITPACK) """
return self._knots
def get_coeffs(self):
""" Spline coefficients (c in FITPACK) """
if self._tck is not None:
return self._tck[1]
else:
raise RuntimeError("SplineModel has not been fit yet")

############
## Spline methods: not tested at all
############
def derivative(self, n=1):
if self._tck is None:
raise RuntimeError("SplineModel has not been fit yet")
else:
t, c, k = self._tck
return scipy.interpolate.BSpline.construct_fast(
t,c,k,extrapolate=(self.extrapolate_mode==0)).derivative(n)
def antiderivative(self, n=1):
if self._tck is None:
raise RuntimeError("SplineModel has not been fit yet")
else:
t, c, k = self._tck
return scipy.interpolate.BSpline.construct_fast(
t,c,k,extrapolate=(self.extrapolate_mode==0)).antiderivative(n)
def integral(self, a, b):
if self._tck is None:
raise RuntimeError("SplineModel has not been fit yet")
else:
t, c, k = self._tck
return scipy.interpolate.BSpline.construct_fast(
t,c,k,extrapolate=(self.extrapolate_mode==0)).integral(a,b)
def derivatives(self, x):
raise NotImplementedError
def roots(self):
raise NotImplementedError

############
## Setters: not really implemented or tested
############
def reset_model(self):
""" Resets model so it needs to be refit to be valid """
self._tck = None
def set_degree(self, degree):
""" Spline degree (k in FITPACK) """
raise NotImplementedError
self._degree = degree
self.reset_model()
def set_smoothing(self, smoothing):
""" Spline smoothing (s in FITPACK) """
raise NotImplementedError
self._smoothing = smoothing
self.reset_model()
def set_knots(self, knots):
""" Spline knots (t in FITPACK) """
raise NotImplementedError
self._knots = self.verify_knots(knots)
self.reset_model()

def set_model_from_tck(self, tck):
"""
Use output of scipy.interpolate.splrep
"""
self._tck = tck

def __call__(self, x, der=0):
"""
Evaluate the model with the given inputs.
der is passed to scipy.interpolate.splev
"""
if self._tck is None:
raise RuntimeError("SplineModel has not been fit yet")
return interpolate.splev(x, self._tck, der=der, ext=self.extrapolate_mode)

####################################
######### Stuff below here is stubs
@property
def param_names(self):
"""
Coefficient names generated based on the model's knots and polynomial degree.
Not Implemented
"""
raise NotImplementedError("SplineModel does not currently expose parameters")
return self._param_names

#def __getattr__(self, attr):
# """
# Fails right now. Future code:
# # From astropy.modeling.polynomial.PolynomialBase
# if self._param_names and attr in self._param_names:
# return Parameter(attr, default=0.0, model=self)
# raise AttributeError(attr)
# """
# raise NotImplementedError("SplineModel does not currently expose parameters")

#def __setattr__(self, attr, value):
# """
# Fails right now. Future code:
# # From astropy.modeling.polynomial.PolynomialBase
# if attr[0] != '_' and self._param_names and attr in self._param_names:
# param = Parameter(attr, default=0.0, model=self)
# param.__set__(self, value)
# else:
# super().__setattr__(attr, value)
# """
# raise NotImplementedError("SplineModel does not currently expose parameters")

def _generate_coeff_names(self):
names = []
degree, Nknots = self._degree, len(self._knots)
for i in range(Nknots):
for j in range(degree+1):
names.append("k{}_c{}".format(i,j))
return tuple(names)

def evaluate(self, *args, **kwargs):
return self(*args, **kwargs)



class SplineFitter(metaclass=_FitterMeta):
"""
Run a spline fit.
"""
def __init__(self):
self.fit_info = {"fp": None,
"ier": None,
"msg": None}
super().__init__()

def validate_model(self, model):
if not isinstance(model, SplineModel):
raise ValueError("model must be of type SplineModel (currently is {})".format(
type(model)))

## TODO do something about units
#@fitter_unit_support
def __call__(self, model, x, y, w=None):
"""
Fit a spline model to data.
Internally uses scipy.interpolate.splrep.
"""

self.validate_model(model)

## Case (1): fit smoothing spline
if model.get_knots() is None:
tck, fp, ier, msg = interpolate.splrep(x, y, w=w,
t=None,
k=model.get_degree(),
s=model.get_smoothing(),
task=0, full_output=True
)
## Case (2): leastsq spline
else:
knots = model.get_knots()
## TODO some sort of validation that the knots are internal, since
## this procedure automatically adds knots at the two endpoints
tck, fp, ier, msg = interpolate.splrep(x, y, w=w,
t=knots,
k=model.get_degree(),
s=model.get_smoothing(),
task=-1, full_output=True
)

model.set_model_from_tck(tck)
self.fit_info.update({"fp":fp, "ier":ier, "msg":msg})

60 changes: 60 additions & 0 deletions specutils/tests/test_spline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import astropy.units as u
import numpy as np

from astropy.modeling import models, fitting
from specutils.fitting.spline import SplineModel, SplineFitter

from scipy import interpolate

def make_data(with_errs=True):
""" Arbitrary data """
np.random.seed(348957)
x = np.linspace(0, 10, 200)
y = (x+1) - (x-5)**2. + 10.*np.exp(-0.5 * ((x-7.)/.5)**2.)
y = (y - np.min(y) + 10.)*10.
if with_errs:
ey = np.sqrt(y)
y = y + np.random.normal(0., ey, y.shape)
w = 1./y
return x, y, w

def test_spline_fit():
x, y, w = make_data()
make_plot=False

# Construct three sets of splines and their scipy equivalents
knots = np.arange(1,10)
models = [SplineModel(), SplineModel(degree=5), SplineModel(knots=knots), SplineModel(smoothing=0)]
labels = ["Deg 3", "Deg 5", "Knots", "Interpolated"]
scipyfit = [interpolate.UnivariateSpline(x,y,w),
interpolate.UnivariateSpline(x,y,w,k=5),
interpolate.LSQUnivariateSpline(x,y,knots,w=w),
interpolate.InterpolatedUnivariateSpline(x,y,w)]

fitter = SplineFitter()
for model, label, scipymodel in zip(models, labels, scipyfit):
fitter(model, x, y, w)
my_y = model(x)
sci_y = scipymodel(x)
assert np.allclose(my_y, sci_y, atol=1e-6)

if make_plot:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(x,y,'k.')
ymin, ymax = np.min(y), np.max(y)
for i,(model, label) in enumerate(zip(models, labels)):
l, = ax.plot(x, model(x), lw=1, label=label)
knots = model.get_knots()
# Hack for now
if knots is None: knots = model._tck[0]
print(knots)
dy = (ymax-ymin)/10.
dy /= i+1.
ax.vlines(knots, ymin, ymin + dy, color=l.get_color(), lw=1)
ax.legend()
plt.show()

if __name__=="__main__":
test_spline_fit()

0 comments on commit 1c88ccf

Please sign in to comment.