Skip to content

Commit

Permalink
Cache partition function interpolation in C++ (#726)
Browse files Browse the repository at this point in the history
This gives a 10-15% overall speedup in my tests with subch_base.
  • Loading branch information
yut23 committed Apr 4, 2024
1 parent ec81269 commit 7a25878
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 20 deletions.
9 changes: 8 additions & 1 deletion pynucastro/networks/base_cxx_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from pynucastro.networks.rate_collection import RateCollection
from pynucastro.networks.sympy_network_support import SympyRates
from pynucastro.rates import DerivedRate
from pynucastro.screening import get_screening_map


Expand Down Expand Up @@ -458,8 +459,14 @@ def _approx_rate_functions(self, n_indent, of):
of.write(r.function_string_cxx(dtype=self.dtype, specifiers=self.function_specifier))

def _fill_reaclib_rates(self, n_indent, of):
if self.derived_rates:
of.write(f"{self.indent*n_indent}part_fun::pf_cache_t pf_cache{{}};\n\n")

for r in self.reaclib_rates + self.derived_rates:
of.write(f"{self.indent*n_indent}rate_{r.cname()}<do_T_derivatives>(tfactors, rate, drate_dT);\n")
if isinstance(r, DerivedRate):
of.write(f"{self.indent*n_indent}rate_{r.cname()}<do_T_derivatives>(tfactors, rate, drate_dT, pf_cache);\n")
else:
of.write(f"{self.indent*n_indent}rate_{r.cname()}<do_T_derivatives>(tfactors, rate, drate_dT);\n")
of.write(f"{self.indent*n_indent}rate_eval.screened_rates(k_{r.cname()}) = rate;\n")
of.write(f"{self.indent*n_indent}if constexpr (std::is_same_v<T, rate_derivs_t>) {{\n")
of.write(f"{self.indent*n_indent} rate_eval.dscreened_rates_dT(k_{r.cname()}) = drate_dT;\n\n")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ namespace part_fun {

}

// Cache of partition function results, indexed by species.
//
// data(inuc, 1) holds the partition function value and data(inuc, 2) its
// temperature derivative, as written by get_partition_function_cached().
// The first index runs over 1..NumSpecTotal and the second over 1..2.
struct pf_cache_t {
// Store the coefficient and derivative adjacent in memory, as they're
// always accessed at the same time.
// The `{}` initializer zeroes every entry, and a zero in the value slot is
// what get_partition_function_cached() treats as "not cached yet".
// NOTE(review): this relies on a computed partition function value never
// being exactly zero; the original comment justified it as "log10(x) is
// never zero" -- confirm against what get_partition_function() returns.
amrex::Array2D<amrex::Real, 1, NumSpecTotal, 1, 2, Order::C> data{};
};

}

// main interface
Expand All @@ -88,6 +96,22 @@ void get_partition_function(const int inuc, [[maybe_unused]] const tf_t& tfactor

}

// Memoizing front end for get_partition_function().
//
// inuc      species index used to address the cache (and forwarded as-is)
// tfactors  temperature factors; only consulted when the entry is missing
// pf_cache  cache to read from / fill in
// pf        [out] partition function value
// dpf_dT    [out] its temperature derivative
//
// The cache is keyed on inuc alone, so a given pf_cache must only ever be
// used together with one set of tfactors.
AMREX_GPU_HOST_DEVICE AMREX_INLINE
void get_partition_function_cached(const int inuc, const tf_t& tfactors,
                                   part_fun::pf_cache_t& pf_cache,
                                   amrex::Real& pf, amrex::Real& dpf_dT) {
    // slot 1 is the value, slot 2 the derivative
    amrex::Real& cached_pf = pf_cache.data(inuc, 1);
    amrex::Real& cached_dpf_dT = pf_cache.data(inuc, 2);

    if (cached_pf == 0.0_rt) {
        // a zero value slot means nothing has been stored yet:
        // do the real interpolation and remember the result
        get_partition_function(inuc, tfactors, pf, dpf_dT);
        cached_pf = pf;
        cached_dpf_dT = dpf_dT;
        return;
    }

    // already computed -- hand back the stored pair
    amrex::ignore_unused(tfactors);
    pf = cached_pf;
    dpf_dT = cached_dpf_dT;
}

// spins

AMREX_GPU_HOST_DEVICE AMREX_INLINE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,14 @@ namespace part_fun {

}

// Cache of partition function results, indexed by species.
//
// data(inuc, 1) holds the partition function value and data(inuc, 2) its
// temperature derivative, as written by get_partition_function_cached().
// The first index runs over 1..NumSpecTotal and the second over 1..2.
struct pf_cache_t {
// Store the coefficient and derivative adjacent in memory, as they're
// always accessed at the same time.
// The `{}` initializer zeroes every entry, and a zero in the value slot is
// what get_partition_function_cached() treats as "not cached yet".
// NOTE(review): this relies on a computed partition function value never
// being exactly zero; the original comment justified it as "log10(x) is
// never zero" -- confirm against what get_partition_function() returns.
amrex::Array2D<amrex::Real, 1, NumSpecTotal, 1, 2, Order::C> data{};
};

}

// main interface
Expand Down Expand Up @@ -183,6 +191,22 @@ void get_partition_function(const int inuc, [[maybe_unused]] const tf_t& tfactor

}

// Memoizing front end for get_partition_function().
//
// inuc      species index used to address the cache (and forwarded as-is)
// tfactors  temperature factors; only consulted when the entry is missing
// pf_cache  cache to read from / fill in
// pf        [out] partition function value
// dpf_dT    [out] its temperature derivative
//
// The cache is keyed on inuc alone, so a given pf_cache must only ever be
// used together with one set of tfactors.
AMREX_GPU_HOST_DEVICE AMREX_INLINE
void get_partition_function_cached(const int inuc, const tf_t& tfactors,
                                   part_fun::pf_cache_t& pf_cache,
                                   amrex::Real& pf, amrex::Real& dpf_dT) {
    // slot 1 is the value, slot 2 the derivative
    amrex::Real& cached_pf = pf_cache.data(inuc, 1);
    amrex::Real& cached_dpf_dT = pf_cache.data(inuc, 2);

    if (cached_pf == 0.0_rt) {
        // a zero value slot means nothing has been stored yet:
        // do the real interpolation and remember the result
        get_partition_function(inuc, tfactors, pf, dpf_dT);
        cached_pf = pf;
        cached_dpf_dT = dpf_dT;
        return;
    }

    // already computed -- hand back the stored pair
    amrex::ignore_unused(tfactors);
    pf = cached_pf;
    dpf_dT = cached_dpf_dT;
}

// spins

AMREX_GPU_HOST_DEVICE AMREX_INLINE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ void rate_He4_Fe52_to_p_Co55(const tf_t& tfactors, amrex::Real& rate, amrex::Rea

template <int do_T_derivatives>
AMREX_GPU_HOST_DEVICE AMREX_INLINE
void rate_Ni56_to_He4_Fe52_derived(const tf_t& tfactors, amrex::Real& rate, amrex::Real& drate_dT) {
void rate_Ni56_to_He4_Fe52_derived(const tf_t& tfactors, amrex::Real& rate, amrex::Real& drate_dT, [[maybe_unused]] part_fun::pf_cache_t& pf_cache) {

// Ni56 --> He4 + Fe52

Expand Down Expand Up @@ -152,7 +152,7 @@ void rate_Ni56_to_He4_Fe52_derived(const tf_t& tfactors, amrex::Real& rate, amre

amrex::Real Ni56_pf, dNi56_pf_dT;
// interpolating Ni56 partition function
get_partition_function(Ni56, tfactors, Ni56_pf, dNi56_pf_dT);
get_partition_function_cached(Ni56, tfactors, pf_cache, Ni56_pf, dNi56_pf_dT);

amrex::Real He4_pf, dHe4_pf_dT;
// setting He4 partition function to 1.0 by default, independent of T
Expand All @@ -161,7 +161,7 @@ void rate_Ni56_to_He4_Fe52_derived(const tf_t& tfactors, amrex::Real& rate, amre

amrex::Real Fe52_pf, dFe52_pf_dT;
// interpolating Fe52 partition function
get_partition_function(Fe52, tfactors, Fe52_pf, dFe52_pf_dT);
get_partition_function_cached(Fe52, tfactors, pf_cache, Fe52_pf, dFe52_pf_dT);

amrex::Real z_r = He4_pf * Fe52_pf;
amrex::Real z_p = Ni56_pf;
Expand All @@ -178,7 +178,7 @@ void rate_Ni56_to_He4_Fe52_derived(const tf_t& tfactors, amrex::Real& rate, amre

template <int do_T_derivatives>
AMREX_GPU_HOST_DEVICE AMREX_INLINE
void rate_Ni56_to_p_Co55_derived(const tf_t& tfactors, amrex::Real& rate, amrex::Real& drate_dT) {
void rate_Ni56_to_p_Co55_derived(const tf_t& tfactors, amrex::Real& rate, amrex::Real& drate_dT, [[maybe_unused]] part_fun::pf_cache_t& pf_cache) {

// Ni56 --> p + Co55

Expand Down Expand Up @@ -209,7 +209,7 @@ void rate_Ni56_to_p_Co55_derived(const tf_t& tfactors, amrex::Real& rate, amrex:

amrex::Real Ni56_pf, dNi56_pf_dT;
// interpolating Ni56 partition function
get_partition_function(Ni56, tfactors, Ni56_pf, dNi56_pf_dT);
get_partition_function_cached(Ni56, tfactors, pf_cache, Ni56_pf, dNi56_pf_dT);

amrex::Real p_pf, dp_pf_dT;
// setting p partition function to 1.0 by default, independent of T
Expand All @@ -218,7 +218,7 @@ void rate_Ni56_to_p_Co55_derived(const tf_t& tfactors, amrex::Real& rate, amrex:

amrex::Real Co55_pf, dCo55_pf_dT;
// interpolating Co55 partition function
get_partition_function(Co55, tfactors, Co55_pf, dCo55_pf_dT);
get_partition_function_cached(Co55, tfactors, pf_cache, Co55_pf, dCo55_pf_dT);

amrex::Real z_r = p_pf * Co55_pf;
amrex::Real z_p = Ni56_pf;
Expand All @@ -235,7 +235,7 @@ void rate_Ni56_to_p_Co55_derived(const tf_t& tfactors, amrex::Real& rate, amrex:

template <int do_T_derivatives>
AMREX_GPU_HOST_DEVICE AMREX_INLINE
void rate_p_Co55_to_He4_Fe52_derived(const tf_t& tfactors, amrex::Real& rate, amrex::Real& drate_dT) {
void rate_p_Co55_to_He4_Fe52_derived(const tf_t& tfactors, amrex::Real& rate, amrex::Real& drate_dT, [[maybe_unused]] part_fun::pf_cache_t& pf_cache) {

// Co55 + p --> He4 + Fe52

Expand Down Expand Up @@ -276,11 +276,11 @@ void rate_p_Co55_to_He4_Fe52_derived(const tf_t& tfactors, amrex::Real& rate, am

amrex::Real Co55_pf, dCo55_pf_dT;
// interpolating Co55 partition function
get_partition_function(Co55, tfactors, Co55_pf, dCo55_pf_dT);
get_partition_function_cached(Co55, tfactors, pf_cache, Co55_pf, dCo55_pf_dT);

amrex::Real Fe52_pf, dFe52_pf_dT;
// interpolating Fe52 partition function
get_partition_function(Fe52, tfactors, Fe52_pf, dFe52_pf_dT);
get_partition_function_cached(Fe52, tfactors, pf_cache, Fe52_pf, dFe52_pf_dT);

amrex::Real z_r = He4_pf * Fe52_pf;
amrex::Real z_p = p_pf * Co55_pf;
Expand All @@ -306,6 +306,8 @@ fill_reaclib_rates(const tf_t& tfactors, T& rate_eval)
amrex::Real rate;
amrex::Real drate_dT;

part_fun::pf_cache_t pf_cache{};

rate_He4_Fe52_to_Ni56<do_T_derivatives>(tfactors, rate, drate_dT);
rate_eval.screened_rates(k_He4_Fe52_to_Ni56) = rate;
if constexpr (std::is_same_v<T, rate_derivs_t>) {
Expand All @@ -324,19 +326,19 @@ fill_reaclib_rates(const tf_t& tfactors, T& rate_eval)
rate_eval.dscreened_rates_dT(k_He4_Fe52_to_p_Co55) = drate_dT;

}
rate_Ni56_to_He4_Fe52_derived<do_T_derivatives>(tfactors, rate, drate_dT);
rate_Ni56_to_He4_Fe52_derived<do_T_derivatives>(tfactors, rate, drate_dT, pf_cache);
rate_eval.screened_rates(k_Ni56_to_He4_Fe52_derived) = rate;
if constexpr (std::is_same_v<T, rate_derivs_t>) {
rate_eval.dscreened_rates_dT(k_Ni56_to_He4_Fe52_derived) = drate_dT;

}
rate_Ni56_to_p_Co55_derived<do_T_derivatives>(tfactors, rate, drate_dT);
rate_Ni56_to_p_Co55_derived<do_T_derivatives>(tfactors, rate, drate_dT, pf_cache);
rate_eval.screened_rates(k_Ni56_to_p_Co55_derived) = rate;
if constexpr (std::is_same_v<T, rate_derivs_t>) {
rate_eval.dscreened_rates_dT(k_Ni56_to_p_Co55_derived) = drate_dT;

}
rate_p_Co55_to_He4_Fe52_derived<do_T_derivatives>(tfactors, rate, drate_dT);
rate_p_Co55_to_He4_Fe52_derived<do_T_derivatives>(tfactors, rate, drate_dT, pf_cache);
rate_eval.screened_rates(k_p_Co55_to_He4_Fe52_derived) = rate;
if constexpr (std::is_same_v<T, rate_derivs_t>) {
rate_eval.dscreened_rates_dT(k_p_Co55_to_He4_Fe52_derived) = drate_dT;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ namespace part_fun {

}

// Cache of partition function results, indexed by species.
//
// data(inuc, 1) holds the partition function value and data(inuc, 2) its
// temperature derivative, as written by get_partition_function_cached().
// The first index runs over 1..NumSpecTotal and the second over 1..2.
struct pf_cache_t {
// Store the coefficient and derivative adjacent in memory, as they're
// always accessed at the same time.
// The `{}` initializer zeroes every entry, and a zero in the value slot is
// what get_partition_function_cached() treats as "not cached yet".
// NOTE(review): this relies on a computed partition function value never
// being exactly zero; the original comment justified it as "log10(x) is
// never zero" -- confirm against what get_partition_function() returns.
amrex::Array2D<amrex::Real, 1, NumSpecTotal, 1, 2, Order::C> data{};
};

}

// main interface
Expand All @@ -88,6 +96,22 @@ void get_partition_function(const int inuc, [[maybe_unused]] const tf_t& tfactor

}

// Memoizing front end for get_partition_function().
//
// inuc      species index used to address the cache (and forwarded as-is)
// tfactors  temperature factors; only consulted when the entry is missing
// pf_cache  cache to read from / fill in
// pf        [out] partition function value
// dpf_dT    [out] its temperature derivative
//
// The cache is keyed on inuc alone, so a given pf_cache must only ever be
// used together with one set of tfactors.
AMREX_GPU_HOST_DEVICE AMREX_INLINE
void get_partition_function_cached(const int inuc, const tf_t& tfactors,
                                   part_fun::pf_cache_t& pf_cache,
                                   amrex::Real& pf, amrex::Real& dpf_dT) {
    // slot 1 is the value, slot 2 the derivative
    amrex::Real& cached_pf = pf_cache.data(inuc, 1);
    amrex::Real& cached_dpf_dT = pf_cache.data(inuc, 2);

    if (cached_pf == 0.0_rt) {
        // a zero value slot means nothing has been stored yet:
        // do the real interpolation and remember the result
        get_partition_function(inuc, tfactors, pf, dpf_dT);
        cached_pf = pf;
        cached_dpf_dT = dpf_dT;
        return;
    }

    // already computed -- hand back the stored pair
    amrex::ignore_unused(tfactors);
    pf = cached_pf;
    dpf_dT = cached_dpf_dT;
}

// spins

AMREX_GPU_HOST_DEVICE AMREX_INLINE
Expand Down
17 changes: 10 additions & 7 deletions pynucastro/rates/rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,16 +1209,17 @@ def function_string_py(self):
fstring += f" rate_eval.{self.fname} = rate\n\n"
return fstring

def function_string_cxx(self, dtype="double", specifiers="inline", leave_open=False):
def function_string_cxx(self, dtype="double", specifiers="inline", leave_open=False, extra_args=()):
"""
Return a string containing C++ function that computes the
rate
"""

args = ["const tf_t& tfactors", f"{dtype}& rate", f"{dtype}& drate_dT", *extra_args]
fstring = ""
fstring += "template <int do_T_derivatives>\n"
fstring += f"{specifiers}\n"
fstring += f"void rate_{self.cname()}(const tf_t& tfactors, {dtype}& rate, {dtype}& drate_dT) {{\n\n"
fstring += f"void rate_{self.cname()}({', '.join(args)}) {{\n\n"
fstring += f" // {self.rid}\n\n"
fstring += " rate = 0.0;\n"
fstring += " drate_dT = 0.0;\n\n"
Expand Down Expand Up @@ -1837,15 +1838,16 @@ def function_string_py(self):

return fstring

def function_string_cxx(self, dtype="double", specifiers="inline", leave_open=False):
def function_string_cxx(self, dtype="double", specifiers="inline", leave_open=False, extra_args=()):
"""
Return a string containing C++ function that computes the
rate
"""

self._warn_about_missing_pf_tables()

fstring = super().function_string_cxx(dtype=dtype, specifiers=specifiers, leave_open=True)
extra_args = ["[[maybe_unused]] part_fun::pf_cache_t& pf_cache", *extra_args]
fstring = super().function_string_cxx(dtype=dtype, specifiers=specifiers, leave_open=True, extra_args=extra_args)

# right now we have rate and drate_dT without the partition function
# now the partition function corrections
Expand All @@ -1858,7 +1860,7 @@ def function_string_cxx(self, dtype="double", specifiers="inline", leave_open=Fa

if nuc.partition_function:
fstring += f" // interpolating {nuc} partition function\n"
fstring += f" get_partition_function({nuc.cindex()}, tfactors, {nuc}_pf, d{nuc}_pf_dT);\n"
fstring += f" get_partition_function_cached({nuc.cindex()}, tfactors, pf_cache, {nuc}_pf, d{nuc}_pf_dT);\n"
else:
fstring += f" // setting {nuc} partition function to 1.0 by default, independent of T\n"
fstring += f" {nuc}_pf = 1.0_rt;\n"
Expand Down Expand Up @@ -2120,7 +2122,7 @@ def function_string_py(self):
string += f" rate_eval.{self.fname} = rate\n\n"
return string

def function_string_cxx(self, dtype="double", specifiers="inline", leave_open=False):
def function_string_cxx(self, dtype="double", specifiers="inline", leave_open=False, extra_args=()):
"""
Return a string containing C++ function that computes the
approximate rate
Expand All @@ -2129,10 +2131,11 @@ def function_string_cxx(self, dtype="double", specifiers="inline", leave_open=Fa
if self.approx_type != "ap_pg":
raise NotImplementedError("don't know how to work with this approximation")

args = ["const T& rate_eval", f"{dtype}& rate", f"{dtype}& drate_dT", *extra_args]
fstring = ""
fstring = "template <typename T>\n"
fstring += f"{specifiers}\n"
fstring += f"void rate_{self.cname()}(const T& rate_eval, {dtype}& rate, {dtype}& drate_dT) {{\n\n"
fstring += f"void rate_{self.cname()}({', '.join(args)}) {{\n\n"

if not self.is_reverse:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ namespace part_fun {

}

// Cache of partition function results, indexed by species.
//
// data(inuc, 1) holds the partition function value and data(inuc, 2) its
// temperature derivative, as written by get_partition_function_cached().
// The first index runs over 1..NumSpecTotal and the second over 1..2.
struct pf_cache_t {
// Store the coefficient and derivative adjacent in memory, as they're
// always accessed at the same time.
// The `{}` initializer zeroes every entry, and a zero in the value slot is
// what get_partition_function_cached() treats as "not cached yet".
// NOTE(review): this relies on a computed partition function value never
// being exactly zero; the original comment justified it as "log10(x) is
// never zero" -- confirm against what get_partition_function() returns.
amrex::Array2D<amrex::Real, 1, NumSpecTotal, 1, 2, Order::C> data{};
};

}

// main interface
Expand All @@ -90,6 +98,22 @@ void get_partition_function(const int inuc, [[maybe_unused]] const tf_t& tfactor

}

// Memoizing front end for get_partition_function().
//
// inuc      species index used to address the cache (and forwarded as-is)
// tfactors  temperature factors; only consulted when the entry is missing
// pf_cache  cache to read from / fill in
// pf        [out] partition function value
// dpf_dT    [out] its temperature derivative
//
// The cache is keyed on inuc alone, so a given pf_cache must only ever be
// used together with one set of tfactors.
AMREX_GPU_HOST_DEVICE AMREX_INLINE
void get_partition_function_cached(const int inuc, const tf_t& tfactors,
                                   part_fun::pf_cache_t& pf_cache,
                                   amrex::Real& pf, amrex::Real& dpf_dT) {
    // slot 1 is the value, slot 2 the derivative
    amrex::Real& cached_pf = pf_cache.data(inuc, 1);
    amrex::Real& cached_dpf_dT = pf_cache.data(inuc, 2);

    if (cached_pf == 0.0_rt) {
        // a zero value slot means nothing has been stored yet:
        // do the real interpolation and remember the result
        get_partition_function(inuc, tfactors, pf, dpf_dT);
        cached_pf = pf;
        cached_dpf_dT = dpf_dT;
        return;
    }

    // already computed -- hand back the stored pair
    amrex::ignore_unused(tfactors);
    pf = cached_pf;
    dpf_dT = cached_dpf_dT;
}

// spins

AMREX_GPU_HOST_DEVICE AMREX_INLINE
Expand Down

0 comments on commit 7a25878

Please sign in to comment.