Skip to content

Commit

Permalink
ressurect eager-mode systematics handling
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Apr 10, 2023
1 parent b3f963a commit 8d7fe3d
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 28 deletions.
14 changes: 9 additions & 5 deletions coffea/nanoevents/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _ensure_systematics(self):
Make sure that the parent object always has a field called '__systematics__'.
"""
if "__systematics__" not in awkward.fields(self):
self["__systematics__"] = {}
self["__systematics__"] = awkward.Array(len(self) * [{}])

@property
def systematics(self):
Expand Down Expand Up @@ -122,23 +122,27 @@ def add_systematic(
if what == "weight" and "__ones__" not in awkward.fields(
flat["__systematics__"]
):
flat["__systematics__", "__ones__"] = numpy.ones(
len(flat), dtype=numpy.float32
)
fields = awkward.fields(flat["__systematics__"])
as_dict = {field: flat["__systematics__", field] for field in fields}
as_dict["__ones__"] = numpy.ones(len(flat), dtype=numpy.float32)
flat["__systematics__"] = awkward.zip(as_dict, depth_limit=1)

rendered_type = flat.layout.parameters["__record__"]
as_syst_type = awkward.with_parameter(flat, "__record__", kind)
as_syst_type._build_variations(name, what, varying_function)
variations = as_syst_type.describe_variations()

flat["__systematics__", name] = awkward.zip(
fields = awkward.fields(flat["__systematics__"])
as_dict = {field: flat["__systematics__", field] for field in fields}
as_dict[name] = awkward.zip(
{
v: getattr(as_syst_type, v)(name, what, rendered_type)
for v in variations
},
depth_limit=1,
with_name=f"{name}Systematics",
)
flat["__systematics__"] = awkward.zip(as_dict, depth_limit=1)

self["__systematics__"] = wrap(flat["__systematics__"])
self.behavior[("__typestr__", f"{name}Systematics")] = f"{kind}"
Expand Down
32 changes: 17 additions & 15 deletions coffea/nanoevents/methods/systematics/UpDownSystematic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ def _build_variations(self, name, what, varying_function, *args, **kwargs):
self[what] if what != "weight" else self["__systematics__", "__ones__"]
)

self["__systematics__", f"__{name}__"] = awkward.virtual(
varying_function,
args=(whatarray, *args),
kwargs=kwargs,
length=len(whatarray),
fields = awkward.fields(self["__systematics__"])
as_dict = {field: self["__systematics__", field] for field in fields}
as_dict[f"__{name}__"] = varying_function(
whatarray,
*args,
**kwargs,
)
self["__systematics__"] = awkward.zip(as_dict, depth_limit=1)

def describe_variations(self):
"""Show the map of variation names to indices."""
Expand Down Expand Up @@ -53,20 +55,20 @@ def get_variation(self, name, what, astype, updown):

def up(self, name, what, astype):
"""Return the "up" variation of this observable."""
return awkward.virtual(
self.get_variation,
args=(name, what, astype, "up"),
length=len(self),
parameters=self[what].layout.parameters if what != "weight" else None,
return self.get_variation(
name,
what,
astype,
"up",
)

def down(self, name, what, astype):
"""Return the "down" variation of this observable."""
return awkward.virtual(
self.get_variation,
args=(name, what, astype, "down"),
length=len(self),
parameters=self[what].layout.parameters if what != "weight" else None,
return self.get_variation(
name,
what,
astype,
"down",
)


Expand Down
2 changes: 1 addition & 1 deletion coffea/nanoevents/schemas/nanoaod.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def _build_collections(self, field_names, input_contents):
output[name].setdefault("parameters", {})
output[name]["parameters"].update({"collection_name": name})

return output.keys(), output.values()
return list(output.keys()), list(output.values())

@property
def behavior(self):
Expand Down
11 changes: 4 additions & 7 deletions coffea/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,12 @@ def deprecate(exception, version, date=None):

# re-nest a record array into a ListArray
def awkward_rewrap(arr, like_what, gfunc):
behavior = awkward._util.behaviorof(like_what)
func = partial(gfunc, data=arr.layout)
layout = awkward.operations.convert.to_layout(like_what)
newlayout = awkward._util.recursively_apply(layout, func)
return awkward._util.wrap(newlayout, behavior=behavior)
return awkward.transform(func, like_what, behavior=like_what.behavior)


# we're gonna assume that the first record array we encounter is the flattened data
def rewrap_recordarray(layout, depth, data):
if isinstance(layout, awkward.layout.RecordArray):
return lambda: data
def rewrap_recordarray(layout, depth, data, **kwargs):
if isinstance(layout, awkward.contents.RecordArray):
return data
return None

0 comments on commit 8d7fe3d

Please sign in to comment.