Skip to content

Commit

Permalink
Merge pull request #100 from bjorgve/quadrature
Browse files Browse the repository at this point in the history
Add quadrature to MWNode and FunctionTree
  • Loading branch information
bjorgve committed Jun 22, 2023
2 parents 41c7c8e + 8fd7e37 commit bbe2319
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions src/vampyr/trees/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <filesystem>

#include <pybind11/stl/filesystem.h>
#include <pybind11/eigen.h>

#include <MRCPP/trees/FunctionNode.h>
#include <MRCPP/trees/FunctionTree.h>
Expand Down Expand Up @@ -58,6 +59,33 @@ template <int D> void trees(pybind11::module &m) {
.def("nGenNodes", &FunctionTree<D>::getNGenNodes)
.def("deleteGenerated", &FunctionTree<D>::deleteGenerated)
.def("integrate", &FunctionTree<D>::integrate)
.def("quadrature",
[](FunctionTree<D> *tree) {

if constexpr (D != 1) {
throw std::runtime_error("quadrature only implemented for 1D");
}

// Current implementation only makes sense in 1D

std::vector<double> vec_pts;
// Iterate over all end nodes
for (int i = 0; i < tree->getNEndNodes(); i++) {
MWNode<D> &node = tree->getEndMWNode(i);

Eigen::MatrixXd pts;
node.getPrimitiveQuadPts(pts);

// Flatten the MatrixXd and add the points from this node to the vector
vec_pts.insert(vec_pts.end(), pts.data(), pts.data() + pts.size());
}

// Now we need to create an Eigen vector from our std::vector
Eigen::VectorXd final_pts = Eigen::Map<Eigen::VectorXd, Eigen::Unaligned>(vec_pts.data(), vec_pts.size());

// Now final_pts holds all the points from all nodes
return final_pts;
})
.def(
"normalize",
[](FunctionTree<D> *out) {
Expand Down Expand Up @@ -263,6 +291,12 @@ template <int D> void trees(pybind11::module &m) {
.def("isGenNode", &MWNode<D>::isGenNode)
.def("hasParent", &MWNode<D>::hasParent)
.def("hasCoefs", &MWNode<D>::hasCoefs)
.def("quadrature",
[](MWNode<D> &node) {
Eigen::MatrixXd pts;
node.getPrimitiveQuadPts(pts);
return pts;
})
.def("center", &MWNode<D>::getCenter)
.def("upperBounds", &MWNode<D>::getUpperBounds)
.def("lowerBounds", &MWNode<D>::getLowerBounds)
Expand Down

0 comments on commit bbe2319

Please sign in to comment.