Skip to content

Commit

Permalink
bug: use casadi MX.interpn_linear function instead of plugin #3783 (#…
Browse files Browse the repository at this point in the history
…4077)

* bug: use casadi MX.interpn_linear function instead of plugin #3783

* bug: fix for 2d and 3d linear interpolant #3783

* cover cubic interpolation in 2d #3783

* #3783 add to changelog

---------

Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com>
  • Loading branch information
martinjrobins and agriyakhetarpal committed May 10, 2024
1 parent edb139e commit 51981c4
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
- Updated `plot_voltage_components.py` to support both `Simulation` and `Solution` objects. Added new methods in both `Simulation` and `Solution` classes for allow the syntax `simulation.plot_voltage_components` and `solution.plot_voltage_components`. Updated `test_plot_voltage_components.py` to reflect these changes ([#3723](https://github.com/pybamm-team/PyBaMM/pull/3723)).
- The SEI thickness decreased at some intervals when the 'electron-migration limited' model was used. It has been corrected ([#3622](https://github.com/pybamm-team/PyBaMM/pull/3622))
- Allow input parameters in ESOH model ([#3921](https://github.com/pybamm-team/PyBaMM/pull/3921))
- Use casadi MX.interpn_linear function instead of plugin to fix casadi_interpolant_linear.dll not found on Windows ([#4077](https://github.com/pybamm-team/PyBaMM/pull/4077))

## Optimizations

Expand Down
3 changes: 2 additions & 1 deletion pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ def __init__(
fill_value_1 = "extrapolate"
interpolating_function = interpolate.interp1d(
x1,
y.T,
y,
bounds_error=False,
fill_value=fill_value_1,
axis=0,
)
elif interpolator == "cubic":
interpolating_function = interpolate.CubicSpline(
Expand Down
32 changes: 24 additions & 8 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,31 @@ def _convert(self, symbol, t, y, y_dot, inputs):
)

if len(converted_children) == 1:
return casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.flatten()
)(*converted_children)
if solver == "linear":
test = casadi.MX.interpn_linear(
symbol.x, symbol.y.flatten(), converted_children
)
if test.shape[0] == 1 and test.shape[1] > 1:
# for some reason, pybamm.Interpolant always returns a column vector, so match that
test = test.T
return test
else:
return casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.flatten()
)(*converted_children)
elif len(converted_children) in [2, 3]:
LUT = casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.ravel(order="F")
)
res = LUT(casadi.hcat(converted_children).T).T
return res
if solver == "linear":
return casadi.MX.interpn_linear(
symbol.x,
symbol.y.ravel(order="F"),
converted_children,
)
else:
LUT = casadi.interpolant(
"LUT", solver, symbol.x, symbol.y.ravel(order="F")
)
res = LUT(casadi.hcat(converted_children).T).T
return res
else: # pragma: no cover
raise ValueError(
f"Invalid converted_children count: {len(converted_children)}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_interpolation_2d(self):
# linear
y_test = np.array([0.4, 0.6])
Y = (2 * x).sum(axis=1).reshape(*[len(el) for el in x_])
for interpolator in ["linear"]:
for interpolator in ["linear", "cubic"]:
interp = pybamm.Interpolant(x_, Y, y, interpolator=interpolator)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])
Expand Down

0 comments on commit 51981c4

Please sign in to comment.