Skip to content

Commit

Permalink
Merge pull request #121 from HopkinsIDD/main-seedingbugfix
Browse files Browse the repository at this point in the history
hacky fix the seeding bug issue #114
  • Loading branch information
shauntruelove committed Nov 3, 2023
2 parents 02a21b4 + 16f81ae commit 0c30c23
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions flepimop/gempyor_pkg/src/gempyor/seeding_ic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def _DataFrame2NumbaDict(df, amounts, setup) -> nb.typed.Dict:

n_seeding_ignored_before = 0
n_seeding_ignored_after = 0

#id_seed = 0
for idx, (row_index, row) in enumerate(df.iterrows()):
if row["place"] not in setup.spatset.nodenames:
raise ValueError(
Expand All @@ -42,6 +44,7 @@ def _DataFrame2NumbaDict(df, amounts, setup) -> nb.typed.Dict:

if (row["date"].date() - setup.ti).days >= 0:
if (row["date"].date() - setup.ti).days < len(nb_seed_perday):

nb_seed_perday[(row["date"].date() - setup.ti).days] = (
nb_seed_perday[(row["date"].date() - setup.ti).days] + 1
)
Expand All @@ -51,6 +54,7 @@ def _DataFrame2NumbaDict(df, amounts, setup) -> nb.typed.Dict:
seeding_dict["seeding_destinations"][idx] = setup.compartments.get_comp_idx(destination_dict)
seeding_dict["seeding_places"][idx] = setup.spatset.nodenames.index(row["place"])
seeding_amounts[idx] = amounts[idx]
#id_seed+=1
else:
n_seeding_ignored_after += 1
else:
Expand Down Expand Up @@ -230,7 +234,15 @@ def draw_seeding(self, sim_id: int, setup) -> nb.typed.Dict:
raise NotImplementedError(f"unknown seeding method [got: {method}]")

# Sorting by date is very important here for the seeding format necessary !!!!
print(seeding.shape)
seeding = seeding.sort_values(by="date", axis="index").reset_index()
print(seeding)
mask = (seeding['date'].dt.date > setup.ti) & (seeding['date'].dt.date <= setup.tf)
seeding = seeding.loc[mask].reset_index()
print(seeding.shape)
print(seeding)

# TODO: print.

amounts = np.zeros(len(seeding))
if method == "PoissonDistributed":
Expand All @@ -240,6 +252,7 @@ def draw_seeding(self, sim_id: int, setup) -> nb.typed.Dict:
elif method == "FolderDraw" or method == "FromFile":
amounts = seeding["amount"]


return _DataFrame2NumbaDict(df=seeding, amounts=amounts, setup=setup)

def load_seeding(self, sim_id: int, setup) -> nb.typed.Dict:
Expand Down

0 comments on commit 0c30c23

Please sign in to comment.