Skip to content

Commit

Permalink
Composition - UserDict subclassing (#731)
Browse files Browse the repository at this point in the history
Changes the Composition class to subclass UserDict, allowing for direct key-value access of nuclei by string (e.g. comp["he4"]) or by Nucleus object (e.g. comp[Nucleus("he4")] or comp.X[Nucleus("he4")]. This also allows for dictionary operations such as for nuc, x in comp.items(): and len(comp). Also adds comp.A and comp.Z getters for dictionaries of the molar masses and charges.
  • Loading branch information
SamG-01 committed May 15, 2024
1 parent c07e7ef commit 6b82659
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 70 deletions.
115 changes: 70 additions & 45 deletions pynucastro/networks/rate_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,62 +86,93 @@ def _skip_xp(n, p, r):
return False


class Composition:
class Composition(collections.UserDict):
"""a composition holds the mass fractions of the nuclei in a network
-- useful for evaluating the rates
"""
def __init__(self, nuclei, small=1.e-16):
"""nuclei is an iterable of the nuclei in the network"""
try:
self.X = {Nucleus.cast(k): small for k in nuclei}
super().__init__({Nucleus.cast(k): small for k in nuclei})
except TypeError:
raise ValueError("must supply an iterable of Nucleus objects or strings") from None

@property
def X(self):
"""backwards-compatible getter for self.X"""
return self.data

@X.setter
def X(self, new_value):
"""backwards-compatible setter for self.X"""
self.data = new_value

def __delitem__(self, key):
super().__delitem__(Nucleus.cast(key))

def __getitem__(self, key):
return super().__getitem__(Nucleus.cast(key))

def __setitem__(self, key, value):
super().__setitem__(Nucleus.cast(key), value)

def __repr__(self):
return "Composition(" + super().__repr__() + ")"

def __str__(self):
ostr = ""
for k in self.X:
ostr += f" X({k}) : {self.X[k]}\n"
return ostr
return "".join(f" X({k}) : {v}\n" for k, v in self.items())

@property
def A(self):
""" return nuclei: molar mass pairs for elements in composition"""
return {n: n.A for n in self}

@property
def Z(self):
""" return nuclei: charge pairs for elements in composition"""
return {n: n.Z for n in self}

def get_nuclei(self):
"""return a list of Nuclei objects that make up this composition"""
return list(self.X)
return list(self)

def get_molar(self):
""" return a dictionary of molar fractions"""
return {k: v/k.A for k, v in self.items()}

def get_sum_X(self):
"""return the sum of the mass fractions"""
return math.fsum(self.X[q] for q in self.X)
return math.fsum(self.values())

def set_solar_like(self, Z=0.02):
""" approximate a solar abundance, setting p to 0.7, He4 to 0.3 - Z and
the remainder evenly distributed with Z """
num = len(self.X)
rem = Z/(num-2)
for k in self.X:
rem = Z/(len(self)-2)
for k in self:
if k == Nucleus("p"):
self.X[k] = 0.7
self[k] = 0.7
elif k.raw == "he4":
self.X[k] = 0.3 - Z
self[k] = 0.3 - Z
else:
self.X[k] = rem
self[k] = rem

self.normalize()

def set_array(self, arr):
""" set all species from a sequence of mass fractions, in the same
order as returned by get_nuclei() """
for i, k in enumerate(self.X):
self.X[k] = arr[i]
for i, k in enumerate(self):
self[k] = arr[i]

def set_all(self, xval):
""" set all species to a particular value """
for k in self.X:
self.X[k] = xval
for k in self:
self[k] = xval

def set_equal(self):
""" set all species to be equal"""
for k in self.X:
self.X[k] = 1.0 / len(self.X)
self.set_all(1.0 / len(self))

def set_random(self, alpha=None, seed=None):
""" set all species using a Dirichlet distribution with
Expand All @@ -151,7 +182,7 @@ def set_random(self, alpha=None, seed=None):

# default is a flat Dirichlet distribution
if alpha is None:
alpha = np.ones(len(self.X))
alpha = np.ones(len(self))

fracs = rng.dirichlet(alpha)
self.set_array(fracs)
Expand All @@ -161,29 +192,23 @@ def set_random(self, alpha=None, seed=None):

def set_nuc(self, name, xval):
""" set nuclei name to the mass fraction xval """
nuc = Nucleus.cast(name)
self.X[nuc] = xval
self[name] = xval

def normalize(self):
""" normalize the mass fractions to sum to 1 """
X_sum = self.get_sum_X()

for k in self.X:
self.X[k] /= X_sum

def get_molar(self):
""" return a dictionary of molar fractions"""
molar_frac = {k: v/k.A for k, v in self.X.items()}
return molar_frac
for k in self:
self[k] /= X_sum

def eval_ye(self):
""" return the electron fraction """
electron_frac = math.fsum(self.X[n] * n.Z / n.A for n in self.X) / math.fsum(self.X[n] for n in self.X)
electron_frac = math.fsum(self[n] * n.Z / n.A for n in self) / self.get_sum_X()
return electron_frac

def eval_abar(self):
""" return the mean molecular weight """
abar = math.fsum(self.X[n] / n.A for n in self.X)
abar = math.fsum(self[n] / n.A for n in self)
return 1. / abar

def eval_zbar(self):
Expand Down Expand Up @@ -217,9 +242,9 @@ def bin_as(self, nuclei, *, verbose=False, exclude=None):
# the abundance in the new, reduced composition and
# remove the nucleus from consideration for the other
# original nuclei
if ex_nuc in nuclei and ex_nuc in self.X:
if ex_nuc in nuclei and ex_nuc in self:
nuclei.remove(ex_nuc)
new_comp.X[ex_nuc] = self.X[ex_nuc]
new_comp[ex_nuc] = self[ex_nuc]
if verbose:
print(f"storing {ex_nuc} as {ex_nuc}")

Expand All @@ -229,7 +254,7 @@ def bin_as(self, nuclei, *, verbose=False, exclude=None):
# loop over our original nuclei. Find the new nucleus such
# that n_orig.A >= n_new.A. If there are multiple, then do
# the same for Z
for old_n, v in self.X.items():
for old_n, v in self.items():

if old_n in exclude:
# we should have already dealt with this above
Expand Down Expand Up @@ -260,7 +285,7 @@ def bin_as(self, nuclei, *, verbose=False, exclude=None):

if verbose:
print(f"storing {old_n} as {match_nuc}")
new_comp.X[match_nuc] += v
new_comp[match_nuc] += v

return new_comp

Expand All @@ -282,11 +307,11 @@ def plot(self, trace_threshold=0.1, hard_limit=None, size=(9, 5)):
trace_keys = []
trace_tot = 0.
main_keys = []
for k in self.X:
for k in self:
# if below threshold, count as trace element
if self.X[k] < trace_threshold:
if self[k] < trace_threshold:
trace_keys.append(k)
trace_tot += self.X[k]
trace_tot += self[k]
else:
main_keys.append(k)

Expand All @@ -296,7 +321,7 @@ def plot(self, trace_threshold=0.1, hard_limit=None, size=(9, 5)):

fig, ax = plt.subplots(1, 1, figsize=size)

ax.pie(self.X.values(), labels=self.X.keys(), autopct=lambda p: f"{p/100:0.3f}")
ax.pie(self.values(), labels=self.keys(), autopct=lambda p: f"{p/100:0.3f}")

else:
# find trace nuclei which contribute little to trace proportion
Expand All @@ -307,8 +332,8 @@ def plot(self, trace_threshold=0.1, hard_limit=None, size=(9, 5)):
limited_trace_keys = []
other_trace_tot = 0.
for k in trace_keys:
if self.X[k] < hard_limit:
other_trace_tot += self.X[k]
if self[k] < hard_limit:
other_trace_tot += self[k]
else:
limited_trace_keys.append(k)

Expand All @@ -317,7 +342,7 @@ def plot(self, trace_threshold=0.1, hard_limit=None, size=(9, 5)):
fig.subplots_adjust(wspace=0)

# pie chart parameters
main_values = [trace_tot] + [self.X[k] for k in main_keys]
main_values = [trace_tot] + [self[k] for k in main_keys]
main_labels = ['trace'] + main_keys
explode = [0.2] + [0. for i in range(len(main_keys))]

Expand All @@ -327,7 +352,7 @@ def plot(self, trace_threshold=0.1, hard_limit=None, size=(9, 5)):
labels=main_labels, explode=explode)

# bar chart parameters
trace_values = [self.X[k] for k in limited_trace_keys] + [other_trace_tot]
trace_values = [self[k] for k in limited_trace_keys] + [other_trace_tot]
trace_labels = [k.short_spec_name for k in limited_trace_keys] + ['other']
bottom = 1
width = 0.1
Expand Down Expand Up @@ -1889,7 +1914,7 @@ def gridplot(self, comp=None, color_field="X", rho=None, T=None, **kwargs):

elif color_field == "x":

values = np.array([comp.X[nuc] for nuc in nuclei])
values = np.array([comp[nuc] for nuc in nuclei])

elif color_field == "y":

Expand Down
54 changes: 29 additions & 25 deletions pynucastro/networks/tests/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,28 @@ def nuclei(self):
def comp(self, nuclei):
return networks.Composition(nuclei)

def test_getitem(self, comp):
n = Nucleus("he4")
assert comp["he4"] == comp[n] == comp.X[n]

def test_solar(self, comp):
comp.set_solar_like()

xsum = sum(comp.X.values())
xsum = sum(comp.values())

assert xsum == approx(1.0)
assert comp.X[Nucleus("h1")] == approx(0.7)
assert comp["h1"] == approx(0.7)

def test_set_all(self, nuclei, comp):
val = 1.0/len(nuclei)
comp.set_all(1.0/len(nuclei))
for n in nuclei:
assert comp.X[n] == val
for x in comp.values():
assert x == val

def test_set_nuc(self, nuclei, comp):
n = nuclei[0]
comp.set_nuc(n.raw, 0.55)
assert comp.X[n] == 0.55
assert comp[n] == 0.55

def test_get_molar(self, comp):
comp.set_solar_like(Z=0.02)
Expand All @@ -47,11 +51,11 @@ def test_get_molar(self, comp):

def test_set_equal(self, nuclei, comp):
comp.set_equal()
assert comp.X[nuclei[0]] == approx(1.0 / len(nuclei))
assert comp["o16"] == approx(1.0 / len(nuclei))

def test_set_random(self, comp):
comp.set_random(seed=0)
assert comp.X[Nucleus("C12")] == approx(0.005076173651329372)
assert comp["C12"] == approx(0.005076173651329372)


class TestCompositionVars:
Expand All @@ -71,20 +75,20 @@ def comp(self, nuclei):
def test_abar(self, comp):

# 1/Abar = sum_k X_k / A_k
Abar_inv = sum(comp.X[n] / n.A for n in comp.X)
Abar_inv = sum(comp[n] / n.A for n in comp)
assert 1.0 / Abar_inv == approx(comp.eval_abar())

def test_zbar(self, comp):

# Zbar = Abar sum_k Z_k X_k / A_k
Abar_inv = sum(comp.X[n] / n.A for n in comp.X)
Zbar = 1.0 / Abar_inv * sum(comp.X[n] * n.Z / n.A for n in comp.X)
Abar_inv = sum(comp[n] / n.A for n in comp)
Zbar = 1.0 / Abar_inv * sum(comp[n] * n.Z / n.A for n in comp)
assert Zbar == approx(comp.eval_zbar())

def test_ye(self, comp):

# Ye = sum_k Z_k X_k / A_k
Ye = sum(comp.X[n] * n.Z / n.A for n in comp.X)
Ye = sum(comp[n] * n.Z / n.A for n in comp)
assert Ye == approx(comp.eval_ye())


Expand Down Expand Up @@ -146,19 +150,19 @@ def test_bin_as(self, nuclei, comp, capsys):
orig_X = 1.0 / len(nuclei)

# we should have placed p and He4 into He4
assert c_new.X[Nucleus("he4")] == approx(2.0 * orig_X)
assert c_new[Nucleus("he4")] == approx(2.0 * orig_X)

# we should have placed C12, C13, N14, O14 into C12
assert c_new.X[Nucleus("c12")] == approx(4.0 * orig_X)
assert c_new[Nucleus("c12")] == approx(4.0 * orig_X)

# we should have placed Cr48 and Fe52 into Ti44
assert c_new.X[Nucleus("ti44")] == approx(2.0 * orig_X)
assert c_new[Nucleus("ti44")] == approx(2.0 * orig_X)

# we should have placed Cr56, Co56, and Fe56 into Fe56
assert c_new.X[Nucleus("fe56")] == approx(3.0 * orig_X)
assert c_new[Nucleus("fe56")] == approx(3.0 * orig_X)

# we should have placed Zn60 into Ni56
assert c_new.X[Nucleus("ni56")] == approx(orig_X)
assert c_new[Nucleus("ni56")] == approx(orig_X)


class TestCompBinning2:
Expand Down Expand Up @@ -193,13 +197,13 @@ def test_bin_as(self, nuclei, comp):
orig_X = 1.0 / len(nuclei)

# we should have placed d, He3. He4. He5, and C12 into p
assert c_new.X[Nucleus("p")] == approx(5.0 * orig_X)
assert c_new["p"] == approx(5.0 * orig_X)

# we should have placed O14, O15, O16, and O17 into N14
assert c_new.X[Nucleus("n14")] == approx(4.0 * orig_X)
assert c_new["n14"] == approx(4.0 * orig_X)

# we should have placed O18 into F18
assert c_new.X[Nucleus("f18")] == approx(orig_X)
assert c_new["f18"] == approx(orig_X)


class TestCompBinning3:
Expand Down Expand Up @@ -228,13 +232,13 @@ def test_bin_as(self, nuclei, comp):
orig_X = 1.0 / len(nuclei)

# we should have placed fe52, fe53 into fe52
assert c_new.X[Nucleus("fe52")] == approx(2.0 * orig_X)
assert c_new[Nucleus("fe52")] == approx(2.0 * orig_X)

# we should have placed fe54, fe55 into fe54
assert c_new.X[Nucleus("fe54")] == approx(2.0 * orig_X)
assert c_new[Nucleus("fe54")] == approx(2.0 * orig_X)

# everything else should be ni56
assert c_new.X[Nucleus("ni56")] == approx(6.0 * orig_X)
assert c_new[Nucleus("ni56")] == approx(6.0 * orig_X)

def test_bin_as_exclude(self, nuclei, comp):
"""exclude Ni56"""
Expand All @@ -248,10 +252,10 @@ def test_bin_as_exclude(self, nuclei, comp):
orig_X = 1.0 / len(nuclei)

# we should have placed fe52, fe53 into fe52
assert c_new.X[Nucleus("fe52")] == approx(2.0 * orig_X)
assert c_new["fe52"] == approx(2.0 * orig_X)

# we should have placed fe54, fe55, fe56, fe57, fe58, ni57, ni58 into fe54
assert c_new.X[Nucleus("fe54")] == approx(7.0 * orig_X)
assert c_new["fe54"] == approx(7.0 * orig_X)

# only ni56 should be ni56
assert c_new.X[Nucleus("ni56")] == approx(orig_X)
assert c_new["ni56"] == approx(orig_X)

0 comments on commit 6b82659

Please sign in to comment.