From 9562aceea2ee132d5dc56c81c9fd3820ef4e9423 Mon Sep 17 00:00:00 2001 From: Magnar Bjorgve Date: Wed, 27 Mar 2024 17:27:27 +0100 Subject: [PATCH] Refactor derivative operator creation with factory function (#115) * Refactor derivative operator creation with factory function This commit introduces a new factory function `Derivative` to the `derivatives` module. This function streamlines the creation of derivative operators by allowing users to specify the type of derivative operator they wish to create using a string identifier. The supported types are "center", "simple", "forward", "backward", "b-spline", and "ph", which correspond to the ABGVOperator with specific parameters and the BSOperator and PHOperator with a specified order. The factory function returns a unique pointer to the created derivative operator, ensuring proper memory management and simplifying the Python interface. Additionally, this commit removes the direct exposure of the specific derivative operator classes to the Python interface, encouraging the use of the new factory function for creating derivative operators. * Default to "center" derivative --------- Co-authored-by: Stig Rune Jensen --- src/vampyr/operators/derivatives.h | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/vampyr/operators/derivatives.h b/src/vampyr/operators/derivatives.h index cd34e47f..7592027e 100644 --- a/src/vampyr/operators/derivatives.h +++ b/src/vampyr/operators/derivatives.h @@ -52,5 +52,26 @@ template void derivatives(pybind11::module &m) { )mydelimiter") // clang-format on .def(py::init &, int>(), "mra"_a, "order"_a = 1); + + + // Factory function to create derivative operators based on type + m.def("Derivative", [](const MultiResolutionAnalysis &mra, const std::string &type = "center", int order = 1) -> std::unique_ptr> { + if (type == "center") { + return std::make_unique>(mra, 0.5, 0.5); + } else if (type == "simple") { + return std::make_unique>(mra, 0.0, 0.0); + } else if (type == "forward") { + return std::make_unique>(mra, 0.0, 1.0); + } else if (type == "backward") { + return std::make_unique>(mra, 1.0, 0.0); + } else if (type == "b-spline") { + return std::make_unique>(mra, order); + } else if (type == "ph") { + return std::make_unique>(mra, order); + } else { + throw std::invalid_argument("Unknown derivative type: " + type); + } + }, "mra"_a, "type"_a, "order"_a = 1); + } } // namespace vampyr