Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create an optimization function which processes the graph and convert… #144

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
130 changes: 122 additions & 8 deletions src/pycel/excelcompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,113 @@ def plot_graph(self, layout_type='spring_layout'):
nx.draw_networkx_labels(self.dep_graph, pos)
plt.show()

def optimized_compile(self, inputs, outputs, numba_optimize=False):
    """Build one Python function computing ``outputs`` from ``inputs``.

    The dependency graph between the input and output cells is collapsed
    into straight-line code: every cell becomes a local variable
    (``C<n>``), every range a local tuple (``R<n>``), and the
    ``_C_("addr")`` / ``_R_("addr")`` lookups found in each cell's
    ``python_code`` are replaced by those local names.

    Conceptually Numba and other JIT tools should allow more optimization,
    but the function calls (rather than local variables) in the generated
    code may prevent numba from inferring types.

    :param inputs: list of cell addresses (not ranges); the returned
        function takes one positional argument per entry, in this order.
    :param outputs: list of cell addresses (not ranges) to compute.
    :param numba_optimize: when true, the generated function is decorated
        with ``numba.jit`` (numba is imported by the generated code).
    :return: a function ``f(*args)`` returning a tuple with one value per
        entry of ``outputs``, in the same order.
    :raises ValueError: if an output does not depend on any of the inputs.
    :raises KeyError: if an address is missing from ``self.cell_map``.
    """
    # Force the evaluation of all the output cells.
    for address in outputs:
        self.evaluate(address)

    # Forward pass: find every cell fed (directly or not) by the inputs.
    input_cells = [self.cell_map[address] for address in inputs]
    input_set = set(input_cells)
    forward_set = set()
    stack = list(input_cells)
    while stack:
        cell = stack.pop()
        if cell in forward_set:
            continue
        forward_set.add(cell)
        stack.extend(self.dep_graph.successors(cell))
    forward_set -= input_set  # the inputs themselves are not computed

    # Seed the backward pass with the outputs. An explicit raise is used
    # instead of ``assert`` so the check survives ``python -O``.
    for address in outputs:
        cell = self.cell_map[address]
        if cell not in forward_set:
            raise ValueError(
                "output {} does not depend on any of the inputs".format(
                    address))
        stack.append(cell)

    lines = []
    if numba_optimize:
        lines.append("from numba import jit")
        lines.append("@jit")
    lines.append("def optimized_function(*args):")

    cell_idx = {}   # cell -> n of its local variable C<n>
    range_idx = {}  # range -> n of its local variable R<n>
    for input_idx, cell in enumerate(input_cells):
        cell_idx[cell] = len(cell_idx)
        lines.append(f"    C{cell_idx[cell]} = args[{input_idx}]")

    # Backward pass: iterative post-order walk so that every cell is
    # emitted after all of its predecessors.
    # Stage 0 (absent): not seen yet; 1: predecessors pushed; 2: emitted.
    stage_of = {}
    while stack:
        cell = stack[-1]
        stage = stage_of.get(cell, 0)

        if stage == 0:
            if cell in forward_set:
                # Visit the predecessors first; this cell stays on the
                # stack and is emitted when it surfaces again (stage 1).
                stage_of[cell] = 1
                stack.extend(self.dep_graph.predecessors(cell))
                continue
            stage_of[cell] = 2
            if cell not in input_set:
                # Independent of the inputs: freeze the current value.
                # repr() keeps strings quoted in the generated source.
                if cell.address.is_range:
                    # Consumers look ranges up in range_idx (R<n>).
                    range_idx[cell] = len(range_idx)
                    lines.append(
                        f"    R{range_idx[cell]} = {cell.value!r}")
                else:
                    cell_idx[cell] = len(cell_idx)
                    lines.append(
                        f"    C{cell_idx[cell]} = {cell.value!r}")
            # Inputs were already bound from ``args`` above.

        elif stage == 1:
            stage_of[cell] = 2
            predecessors = list(self.dep_graph.predecessors(cell))
            if cell.address.is_range:
                range_idx[cell] = len(range_idx)
                # A trailing comma after every member keeps a one-cell
                # range a tuple: "(C0,)" rather than the scalar "(C0)".
                members = "".join(
                    f"C{cell_idx[pred]}," for pred in predecessors)
                lines.append(f"    R{range_idx[cell]} = ({members})")
            else:
                python_code = cell.python_code
                for pred in predecessors:
                    if pred.address.is_range:
                        variable_name = f"R{range_idx[pred]}"
                        old_name = f'_R_("{pred.address}")'
                    else:
                        variable_name = f"C{cell_idx[pred]}"
                        old_name = f'_C_("{pred.address}")'
                    python_code = python_code.replace(
                        old_name, variable_name)
                cell_idx[cell] = len(cell_idx)
                lines.append(f"    C{cell_idx[cell]} = {python_code}")

        stack.pop()

    lines.append("    results = (")
    lines.extend(
        f"        C{cell_idx[self.cell_map[address]]},"
        for address in outputs)
    lines.append("    )")
    lines.append("    return results")
    code = "\n".join(lines) + "\n"

    # Only these helpers (plus the builtins) are visible to the generated
    # code; a formula using anything else raises NameError when called.
    gbls = {
        "sin": math.sin,
        "sum_": sum,
    }
    # NOTE(security): this executes source text derived from the
    # workbook's formulas; only use it with workbooks you trust.
    exec(code, gbls)  # exec will create optimized_function
    return gbls["optimized_function"]


def set_value(self, address, value, set_as_range=False):
""" Set the value of one or more cells or ranges

Expand Down Expand Up @@ -456,15 +563,22 @@ def set_value(self, address, value, set_as_range=False):
cell_or_range.value = value

def _reset(self, cell):
if cell.needs_calc:
return
self.log.info(f"Resetting {cell.address}")
cell.value = None
cell_stack = [ cell ]
while cell_stack:
# use iterative version of recursion to prevent exceeding max recursive depth.
next_cell = cell_stack.pop()
if next_cell.needs_calc:
continue

self.log.info(f"Resetting {next_cell.address}")
next_cell.value = None

if next_cell in self.dep_graph:
for child_cell in self.dep_graph.successors(next_cell):
if child_cell.value is not None:
cell_stack.append(child_cell)


if cell in self.dep_graph:
for child_cell in self.dep_graph.successors(cell):
if child_cell.value is not None:
self._reset(child_cell)

def value_tree_str(self, address, indent=0):
iterative_eval_tracker.inc_iteration_number()
Expand Down
Binary file added tests/fixtures/deep_chain.xlsx
Binary file not shown.
51 changes: 51 additions & 0 deletions tests/test_excelcompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,3 +1132,54 @@ def test_evaluate_after_range_eval_error():
with pytest.raises(UnknownFunction):
excel_compiler.evaluate('A5')
assert excel_compiler.evaluate('A3') == 'hello'


def test_deep_chain(fixture_xls_copy):
    """Compiled functions reproduce the values of a deep dependency chain."""
    compiler = ExcelCompiler(fixture_xls_copy('deep_chain.xlsx'))

    # optimized_compile returns a tuple ordered like the requested outputs,
    # so results are read by position, not by address.
    fn1 = compiler.optimized_compile(
        ['Switch!A1'], ['Switch!A10', 'Switch!A500'])
    a10, a500 = fn1(10)
    assert a10 == 19
    assert a500 == 509

    fn2 = compiler.optimized_compile(['Switch!A1'], ['Switch!A14'])
    a14, = fn2(20)
    assert a14 == 33



def test_deep_chain_timing(fixture_xls_copy):
    """Benchmark native evaluation against the compiled variants.

    NOTE: this test is slow (three timeit runs of 10000 calls each) and
    the timings are likely inconsistent between machines, so they are only
    printed, never asserted. Not sure of the best way to deal with this;
    consider excluding it from the default test run.
    """
    # The numba variant needs an optional dependency: skip, don't error.
    pytest.importorskip("numba")
    import timeit

    output_list = ['Switch!A20', 'Switch!B20', 'Switch!C20']

    compiler = ExcelCompiler(fixture_xls_copy('deep_chain.xlsx'))
    fn1 = compiler.optimized_compile(['Switch!A1'], output_list)

    fn1_optimized = compiler.optimized_compile(
        ['Switch!A1'], output_list, numba_optimize=True)
    # Very important: call once before timeit. Timeit somehow breaks
    # numba's caching.
    fn1_optimized(10)

    def fn1_native(value):
        compiler.set_value('Switch!A1', value)
        return tuple(compiler.evaluate(address) for address in output_list)

    result_compiled = fn1(10)
    result_native = fn1_native(10)
    result_optimized = fn1_optimized(10)
    assert result_compiled == result_native
    assert result_optimized == result_compiled

    result = timeit.timeit(
        "fn1_native(it); it = it + 1", "it = 0",
        number=10000, globals=locals())
    print("\nNative time is %s" % result)

    result = timeit.timeit(
        "fn1(it); it = it + 1", "it = 0",
        number=10000, globals=locals())
    print("Optimized time is %s" % result)

    result = timeit.timeit(
        "fn1_optimized(it); it = it + 1", "it = 0",
        number=10000, globals=locals())
    print("Numba Optimized time is %s" % result)

    # inspect_types() writes its report to stdout itself; wrapping it in
    # print() would only add a stray "None".
    fn1_optimized.inspect_types()