Skip to content
This repository has been archived by the owner on Oct 25, 2022. It is now read-only.

Commit

Permalink
Merge pull request #18 from scikit-hep/issue-6
Browse files Browse the repository at this point in the history
`+` should unify buffer types
  • Loading branch information
jpivarski committed Apr 12, 2019
2 parents 663af8e + 4f90efc commit 90cf289
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 24 deletions.
58 changes: 35 additions & 23 deletions python/aghast/interface.py
Expand Up @@ -815,21 +815,33 @@ def _remap(self, newshape, selfmap):
endianness=self.endianness,
dimension_order=self.dimension_order)

def _dontclobber(self, other):
return self.dtype != other.dtype or self.endianness != other.endianness or self.dimension_order != other.dimension_order

def _add(self, other, noclobber, op=numpy.add):
if noclobber:
if isinstance(self, InterpretedInlineBuffer) or isinstance(other, InterpretedInlineBuffer):
return InterpretedInlineBuffer(op(self.flatarray, other.flatarray).view(numpy.uint8),
filters=self.filters,
postfilter_slice=self.postfilter_slice,
dtype=self.dtype,
endianness=self.endianness,
dimension_order=self.dimension_order)

elif isinstance(self, InterpretedInlineFloat64Buffer) or isinstance(other, InterpretedInlineFloat64Buffer):
return InterpretedInlineFloat64Buffer(op(self.flatarray, other.flatarray).view(numpy.uint8))

elif isinstance(self, InterpretedInlineInt64Buffer) or isinstance(other, InterpretedInlineInt64Buffer):
return InterpretedInlineInt64Buffer(op(self.flatarray, other.flatarray).view(numpy.uint8))
out = op(self.flatarray, other.flatarray) # FIXME: what if they have different dimension_orders?
dtype, endianness = self.from_numpy_dtype(out.dtype)

if (dtype == InterpretedInlineFloat64Buffer.dtype.fget(None) and
endianness == InterpretedInlineFloat64Buffer.endianness.fget(None) and
self.dimension_order == InterpretedInlineFloat64Buffer.dimension_order.fget(None) and
other.dimension_order == InterpretedInlineFloat64Buffer.dimension_order.fget(None)):
return InterpretedInlineFloat64Buffer(out.view(numpy.uint8))

elif (dtype == InterpretedInlineInt64Buffer.dtype.fget(None) and
endianness == InterpretedInlineInt64Buffer.endianness.fget(None) and
self.dimension_order == InterpretedInlineInt64Buffer.dimension_order.fget(None) and
other.dimension_order == InterpretedInlineInt64Buffer.dimension_order.fget(None)):
return InterpretedInlineInt64Buffer(out.view(numpy.uint8))

elif isinstance(self, InterpretedInlineBuffer) or isinstance(other, InterpretedInlineBuffer):
return InterpretedInlineBuffer(out.view(numpy.uint8),
filters=None,
postfilter_slice=None,
dtype=dtype,
endianness=endianness,
dimension_order=self.fortran_order if numpy.isfortran(out) else self.c_order)

else:
raise AssertionError((type(self), type(other)))
Expand Down Expand Up @@ -1413,7 +1425,7 @@ def _dump(self, indent, width, end):
return _dumpline(self, args, indent, width, end)

def _add(self, other, noclobber, op=numpy.add):
if noclobber or self.external_source != self.memory or len(self.filters) != 0:
if noclobber or self.external_source != self.memory or len(self.filters) != 0 or self.postfilter_slice is not None:
return super(InterpretedExternalBuffer, self)._add(other, noclobber, op=op)

else:
Expand Down Expand Up @@ -1544,7 +1556,7 @@ def _dump(self, indent, width, end):
return _dumpline(self, args, indent, width, end)

def _add(self, other, noclobber):
return Moments(self.sumwxn._add(other.sumwxn, noclobber), self.n, weightpower=self.weightpower, filter=(None if self.filter is None else self.filter.detached()))
return Moments(self.sumwxn._add(other.sumwxn, noclobber or self.sumwxn._dontclobber(other.sumwxn)), self.n, weightpower=self.weightpower, filter=(None if self.filter is None else self.filter.detached()))

################################################# Extremes

Expand Down Expand Up @@ -1599,10 +1611,10 @@ def _dump(self, indent, width, end):
return _dumpline(self, args, indent, width, end)

def _min(self, other, noclobber):
return Extremes(self.values._add(other.values, noclobber, op=numpy.minimum), filter=(None if self.filter is None else self.filter.detached()))
return Extremes(self.values._add(other.values, noclobber or self.values._dontclobber(other.values), op=numpy.minimum), filter=(None if self.filter is None else self.filter.detached()))

def _max(self, other, noclobber):
return Extremes(self.values._add(other.values, noclobber, op=numpy.maximum), filter=(None if self.filter is None else self.filter.detached()))
return Extremes(self.values._add(other.values, noclobber or self.values._dontclobber(other.values), op=numpy.maximum), filter=(None if self.filter is None else self.filter.detached()))

################################################# Quantiles

Expand Down Expand Up @@ -1826,7 +1838,7 @@ def _dump(self, indent, width, end):

def _add(self, other, noclobber):
return Statistics(
moments=sum([[x._add(y, noclobber) for y in other.moments if x.n == y.n and x.weightpower == y.weightpower] for x in self.moments], []),
moments=sum([[x._add(y, noclobber or x._dontclobber(y)) for y in other.moments if x.n == y.n and x.weightpower == y.weightpower] for x in self.moments], []),
min=None if self.min is None or other.min is None else self.min._min(other.min, noclobber),
max=None if self.max is None or other.max is None else self.max._max(other.max, noclobber))

Expand Down Expand Up @@ -4963,7 +4975,7 @@ def _remap(self, newshape, selfmap):

def _add(self, other, noclobber):
assert isinstance(other, UnweightedCounts)
self.counts = self.counts._add(other.counts, noclobber)
self.counts = self.counts._add(other.counts, noclobber or self.counts._dontclobber(other.counts))
return self

@property
Expand Down Expand Up @@ -5063,15 +5075,15 @@ def _remap(self, newshape, selfmap):
def _add(self, other, noclobber):
assert isinstance(other, WeightedCounts)

self.sumw = self.sumw._add(other.sumw, noclobber)
self.sumw = self.sumw._add(other.sumw, noclobber or self.sumw._dontclobber(other.sumw))

if self.sumw2 is not None and other.sumw2 is not None:
self.sumw2 = self.sumw2._add(other.sumw2, noclobber)
self.sumw2 = self.sumw2._add(other.sumw2, noclobber or self.sumw2._dontclobber(other.sumw2))
else:
self.sumw2 = None

if self.unweighted is not None and other.unweighted is not None:
self.unweighted = self.unweighted._add(other.unweighted, noclobber)
self.unweighted = self.unweighted._add(other.unweighted, noclobber or self.unweighted._dontclobber(other.unweighted))
else:
self.unweighted = None

Expand Down Expand Up @@ -5912,7 +5924,7 @@ def _add(self, other, pairs, triples, noclobber):

for selfaxis, otheraxis, (binning, sm, om) in zip(self.axis, other.axis, triples[-len(self.axis):]):
selfaxis.binning = binning
selfaxis.statistics = [x._add(y, noclobber) for x, y in zip(selfaxis.statistics, otheraxis.statistics)]
selfaxis.statistics = [x._add(y, noclobber or x._dontclobber(y)) for x, y in zip(selfaxis.statistics, otheraxis.statistics)]

if self.counts is not selfcounts:
self.counts = selfcounts
Expand Down
2 changes: 1 addition & 1 deletion python/aghast/version.py
Expand Up @@ -4,7 +4,7 @@

import re

__version__ = "0.2.0"
__version__ = "0.2.1"
version = __version__
version_info = tuple(re.split(r"[-\.]", __version__))

Expand Down

0 comments on commit 90cf289

Please sign in to comment.