Skip to content

Commit

Permalink
Merge pull request #102 from robertodr/fix-101
Browse files Browse the repository at this point in the history
__add__ and __iadd__
  • Loading branch information
bjorgve committed Jul 13, 2023
2 parents bbe2319 + 8d44a79 commit 9e2ac3b
Showing 1 changed file with 46 additions and 53 deletions.
99 changes: 46 additions & 53 deletions src/vampyr/trees/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

#include <filesystem>

#include <pybind11/stl/filesystem.h>
#include <pybind11/eigen.h>
#include <pybind11/stl/filesystem.h>

#include <MRCPP/trees/FunctionNode.h>
#include <MRCPP/trees/FunctionTree.h>
Expand All @@ -12,6 +12,18 @@
#include <MRCPP/trees/TreeIterator.h>

namespace vampyr {
template <int D>
auto impl__add__(mrcpp::FunctionTree<D> *inp_a, mrcpp::FunctionTree<D> *inp_b) -> std::unique_ptr<mrcpp::FunctionTree<D>> {
using namespace mrcpp;
auto out = std::make_unique<FunctionTree<D>>(inp_a->getMRA());
FunctionTreeVector<D> vec;
vec.push_back({1.0, inp_a});
vec.push_back({1.0, inp_b});
build_grid(*out, vec);
add(-1.0, *out, vec);
return out;
};

template <int D> void trees(pybind11::module &m) {
using namespace mrcpp;
namespace py = pybind11;
Expand All @@ -30,12 +42,11 @@ template <int D> void trees(pybind11::module &m) {
py::return_value_policy::reference_internal)
.def("rootScale", &MWTree<D>::getRootScale)
.def("depth", &MWTree<D>::getDepth)
.def(
"setZero",
[](MWTree<D> *out) {
out->setZero();
return out;
})
.def("setZero",
[](MWTree<D> *out) {
out->setZero();
return out;
})
.def("clear", &MWTree<D>::clear)
.def("setName", &MWTree<D>::setName)
.def("name", &MWTree<D>::getName)
Expand All @@ -61,37 +72,34 @@ template <int D> void trees(pybind11::module &m) {
.def("integrate", &FunctionTree<D>::integrate)
.def("quadrature",
[](FunctionTree<D> *tree) {
if constexpr (D != 1) { throw std::runtime_error("quadrature only implemented for 1D"); }

if constexpr (D != 1) {
throw std::runtime_error("quadrature only implemented for 1D");
}
// Current implementation only makes sense in 1D

// Current implementation only makes sense in 1D
std::vector<double> vec_pts;
// Iterate over all end nodes
for (int i = 0; i < tree->getNEndNodes(); i++) {
MWNode<D> &node = tree->getEndMWNode(i);

std::vector<double> vec_pts;
// Iterate over all end nodes
for (int i = 0; i < tree->getNEndNodes(); i++) {
MWNode<D> &node = tree->getEndMWNode(i);
Eigen::MatrixXd pts;
node.getPrimitiveQuadPts(pts);

Eigen::MatrixXd pts;
node.getPrimitiveQuadPts(pts);
// Flatten the MatrixXd and add the points from this node to the vector
vec_pts.insert(vec_pts.end(), pts.data(), pts.data() + pts.size());
}

// Flatten the MatrixXd and add the points from this node to the vector
vec_pts.insert(vec_pts.end(), pts.data(), pts.data() + pts.size());
}
// Now we need to create an Eigen vector from our std::vector
Eigen::VectorXd final_pts =
Eigen::Map<Eigen::VectorXd, Eigen::Unaligned>(vec_pts.data(), vec_pts.size());

// Now we need to create an Eigen vector from our std::vector
Eigen::VectorXd final_pts = Eigen::Map<Eigen::VectorXd, Eigen::Unaligned>(vec_pts.data(), vec_pts.size());

// Now final_pts holds all the points from all nodes
return final_pts;
// Now final_pts holds all the points from all nodes
return final_pts;
})
.def("normalize",
[](FunctionTree<D> *out) {
out->normalize();
return out;
})
.def(
"normalize",
[](FunctionTree<D> *out) {
out->normalize();
return out;
})
.def(
"saveTree",
[](FunctionTree<D> &obj, const std::string &filename) {
Expand Down Expand Up @@ -137,25 +145,10 @@ template <int D> void trees(pybind11::module &m) {
return out;
},
py::is_operator())
.def(
"__add__",
[](FunctionTree<D> *inp_a, FunctionTree<D> *inp_b) {
auto out = std::make_unique<FunctionTree<D>>(inp_a->getMRA());
FunctionTreeVector<D> vec;
vec.push_back({1.0, inp_a});
vec.push_back({1.0, inp_b});
build_grid(*out, vec);
add(-1.0, *out, vec);
return out;
},
py::is_operator())
.def("__add__", &impl__add__<D>, py::is_operator())
.def(
"__iadd__",
[](FunctionTree<D> *out, FunctionTree<D> *inp) {
refine_grid(*out, *inp);
out->add(1.0, *inp);
return out;
},
[](FunctionTree<D> *out, FunctionTree<D> *inp) { return impl__add__<D>(out, inp); },
py::is_operator())
.def(
"__sub__",
Expand Down Expand Up @@ -292,11 +285,11 @@ template <int D> void trees(pybind11::module &m) {
.def("hasParent", &MWNode<D>::hasParent)
.def("hasCoefs", &MWNode<D>::hasCoefs)
.def("quadrature",
[](MWNode<D> &node) {
Eigen::MatrixXd pts;
node.getPrimitiveQuadPts(pts);
return pts;
})
[](MWNode<D> &node) {
Eigen::MatrixXd pts;
node.getPrimitiveQuadPts(pts);
return pts;
})
.def("center", &MWNode<D>::getCenter)
.def("upperBounds", &MWNode<D>::getUpperBounds)
.def("lowerBounds", &MWNode<D>::getLowerBounds)
Expand Down

0 comments on commit 9e2ac3b

Please sign in to comment.