diff --git a/src/vampyr/operators/derivatives.h b/src/vampyr/operators/derivatives.h index cd34e47..7592027 100644 --- a/src/vampyr/operators/derivatives.h +++ b/src/vampyr/operators/derivatives.h @@ -52,5 +52,26 @@ template void derivatives(pybind11::module &m) { )mydelimiter") // clang-format on .def(py::init &, int>(), "mra"_a, "order"_a = 1); + + + // Factory function to create derivative operators based on type + m.def("Derivative", [](const MultiResolutionAnalysis &mra, const std::string &type = "center", int order = 1) -> std::unique_ptr> { + if (type == "center") { + return std::make_unique>(mra, 0.5, 0.5); + } else if (type == "simple") { + return std::make_unique>(mra, 0.0, 0.0); + } else if (type == "forward") { + return std::make_unique>(mra, 0.0, 1.0); + } else if (type == "backward") { + return std::make_unique>(mra, 1.0, 0.0); + } else if (type == "b-spline") { + return std::make_unique>(mra, order); + } else if (type == "ph") { + return std::make_unique>(mra, order); + } else { + throw std::invalid_argument("Unknown derivative type: " + type); + } + }, "mra"_a, "type"_a, "order"_a = 1); + } } // namespace vampyr