Skip to content

Commit

Permalink
Refactor derivative operator creation with factory function (#115)
Browse files Browse the repository at this point in the history
* Refactor derivative operator creation with factory function

This commit introduces a new factory function `Derivative` to the
`derivatives` module. This function streamlines the creation of
derivative operators by allowing users to specify the type of
derivative operator they wish to create using a string identifier.
The supported types are "center", "simple", "forward", "backward",
"b-spline", and "ph", which correspond to the ABGVOperator with
specific parameters and the BSOperator and PHOperator with a
specified order.

The factory function returns a unique pointer to the created
derivative operator, ensuring proper memory management and
simplifying the Python interface.

Additionally, this commit removes the direct exposure of the
specific derivative operator classes to the Python interface,
encouraging the use of the new factory function for creating
derivative operators.

* Default to "center" derivative

---------

Co-authored-by: Stig Rune Jensen <stig.r.jensen@oceanbox.io>
  • Loading branch information
bjorgve and stigrj committed Mar 27, 2024
1 parent 2c52b30 commit 9562ace
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/vampyr/operators/derivatives.h
Expand Up @@ -52,5 +52,26 @@ template <int D> void derivatives(pybind11::module &m) {
)mydelimiter")
// clang-format on
.def(py::init<const MultiResolutionAnalysis<D> &, int>(), "mra"_a, "order"_a = 1);


// Factory function to create derivative operators based on type
m.def("Derivative", [](const MultiResolutionAnalysis<D> &mra, const std::string &type = "center", int order = 1) -> std::unique_ptr<DerivativeOperator<D>> {
if (type == "center") {
return std::make_unique<ABGVOperator<D>>(mra, 0.5, 0.5);
} else if (type == "simple") {
return std::make_unique<ABGVOperator<D>>(mra, 0.0, 0.0);
} else if (type == "forward") {
return std::make_unique<ABGVOperator<D>>(mra, 0.0, 1.0);
} else if (type == "backward") {
return std::make_unique<ABGVOperator<D>>(mra, 1.0, 0.0);
} else if (type == "b-spline") {
return std::make_unique<BSOperator<D>>(mra, order);
} else if (type == "ph") {
return std::make_unique<PHOperator<D>>(mra, order);
} else {
throw std::invalid_argument("Unknown derivative type: " + type);
}
}, "mra"_a, "type"_a, "order"_a = 1);

}
} // namespace vampyr

0 comments on commit 9562ace

Please sign in to comment.