diff --git a/pynucastro/networks/rate_collection.py b/pynucastro/networks/rate_collection.py index fd313404f..aa775db4f 100644 --- a/pynucastro/networks/rate_collection.py +++ b/pynucastro/networks/rate_collection.py @@ -127,6 +127,12 @@ def set_solar_like(self, Z=0.02): self.normalize() + def set_array(self, arr): + """ set all species from a sequence of mass fractions, in the same + order as returned by get_nuclei() """ + for i, k in enumerate(self.X): + self.X[k] = arr[i] + def set_all(self, xval): """ set all species to a particular value """ for k in self.X: @@ -137,6 +143,22 @@ def set_equal(self): for k in self.X: self.X[k] = 1.0 / len(self.X) + def set_random(self, alpha=None, seed=None): + """ set all species using a Dirichlet distribution with + parameters alpha and specified rng seed """ + # initializes random seed + rng = np.random.default_rng(seed) + + # default is a flat Dirichlet distribution + if alpha is None: + alpha = np.ones(len(self.X)) + + fracs = rng.dirichlet(alpha) + self.set_array(fracs) + + # ensures exact normalization + self.normalize() + def set_nuc(self, name, xval): """ set nuclei name to the mass fraction xval """ nuc = Nucleus.cast(name) diff --git a/pynucastro/networks/tests/test_composition.py b/pynucastro/networks/tests/test_composition.py index e85003dd5..3ee3d6cfc 100644 --- a/pynucastro/networks/tests/test_composition.py +++ b/pynucastro/networks/tests/test_composition.py @@ -49,6 +49,10 @@ def test_set_equal(self, nuclei, comp): comp.set_equal() assert comp.X[nuclei[0]] == approx(1.0 / len(nuclei)) + def test_set_random(self, comp): + comp.set_random(seed=0) + assert comp.X[Nucleus("C12")] == approx(0.005076173651329372) + class TestCompositionVars: @pytest.fixture(scope="class")