Skip to content
This repository has been archived by the owner on Jun 21, 2022. It is now read-only.

Commit

Permalink
iterator for arrowed functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Oct 12, 2017
1 parent b2abc53 commit 38e0056
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 1 deletion.
51 changes: 51 additions & 0 deletions uproot/_connect/toarrowed.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
from arrowed.schema import Record
from arrowed.schema import Pointer

from uproot import iterator

def tostr(obj):
if hasattr(obj, "decode"):
return obj.decode("ascii")
Expand Down Expand Up @@ -209,3 +211,52 @@ def getarray(tree, schema):
_delayedraise(cls, err, trc)

return compiled(resolved, *args)

def compile(tree, function, paramtypes={}, env={}, numba={"nopython": True, "nogil": True}, fcncache=None, schema=None, debug=False):
if schema is None:
schema = _schema(tree)
dtypes = dict((b.name, b.dtype) for b in tree.allbranches if getattr(b, "dtype", None) is not None)
return ArrowedFunction(schema, dtypes, schema.compile(function, paramtypes=paramtypes, env=env, numba=numba, fcncache=fcncache, debug=debug))

class ArrowedFunction(object):
def __init__(self, schema, dtypes, compiled):
self._schema = schema
self._dtypes = dtypes
self._compiled = compiled

def run(self, entries, path, treepath, args=(), memmap=True, executor=None, datacache=None, reportentries=False):
if datacache is not None:
raise NotImplementedError

entriesarray = [entries]

def getarray(arrays, schema):
branchname = schema.name
if branchname is None:
array = numpy.array(entriesarray, dtype=numpy.int32)
else:
array = arrays[branchname]
return array

accessor = self._schema.accessedby(getarray)

branchdtypes = {}
for parameter in self._compiled.parameters.transformed:
for symbol, (member, attr) in parameter.sym2obj.items():
branchname = member.name
if branchname is not None:
branchdtypes[branchname] = self._dtypes[branchname].newbyteorder("=")

for entrystart, entryend, arrays in iterator(entries, path, treepath, branchdtypes, memmap=memmap, executor=executor, reportentries=True):
entriesarray = [entryend - entrystart]

resolved = accessor.resolved(arrays, lazy=True)
for parameter in self._compiled.parameters.transformed:
for symbol, (member, attr) in parameter.sym2obj.items():
resolved.findbybase(member).get(attr)

out = self._compiled(resolved, *args)
if reportentries:
yield entrystart, entryend, out
else:
yield out
4 changes: 4 additions & 0 deletions uproot/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,10 @@ def run(*args, **kwds):
return uproot._connect.toarrowed.run(self, *args, **kwds)
connector.run = run

def compile(*args, **kwds):
return uproot._connect.toarrowed.compile(self, *args, **kwds)
connector.compile = compile

return connector

@property
Expand Down
2 changes: 1 addition & 1 deletion uproot/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

import re

__version__ = "1.5.2"
__version__ = "1.5.3"
version = __version__
version_info = tuple(re.split(r"[-\.]", __version__))

Expand Down

0 comments on commit 38e0056

Please sign in to comment.