Skip to content

Commit

Permalink
dry refactor for recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanzwicknagl committed Apr 7, 2023
1 parent ac77654 commit 13b846c
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 100 deletions.
64 changes: 60 additions & 4 deletions backend/src/viasp/asp/justify.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from networkx import DiGraph

from .reify import ProgramAnalyzer
from .recursion import get_recursion_subgraph
from .recursion import RecursionReasoner
from .utils import insert_atoms_into_nodes, identify_reasons
from ..shared.model import Node, Transformation, SymbolIdentifier
from ..shared.simple_logging import info, warn
Expand Down Expand Up @@ -56,7 +56,7 @@ def get_facts(original_program) -> Collection[Symbol]:
return frozenset(facts)


def collect_h_symbols_and_create_nodes(h_symbols: Collection[Symbol], relevant_indices, pad: bool) -> List[Node]:
def collect_h_symbols_and_create_nodes(h_symbols: Collection[Symbol], relevant_indices, pad: bool, supernode_symbols: frozenset = frozenset([])) -> List[Node]:
tmp_symbol: Dict[int, List[SymbolIdentifier]] = defaultdict(list)
tmp_reason: Dict[int, Dict[Symbol, List[Symbol]]] = defaultdict(dict)
for sym in h_symbols:
Expand All @@ -65,13 +65,18 @@ def collect_h_symbols_and_create_nodes(h_symbols: Collection[Symbol], relevant_i
tmp_reason[rule_nr.number][str(symbol)] = reasons.arguments
for rule_nr in tmp_symbol.keys():
tmp_symbol[rule_nr] = set(tmp_symbol[rule_nr])
tmp_symbol[rule_nr] = map(SymbolIdentifier, tmp_symbol[rule_nr])
tmp_symbol[rule_nr] = map(lambda symbol: next(filter(
lambda supernode_symbol: supernode_symbol==symbol, supernode_symbols)) if
symbol in supernode_symbols else
SymbolIdentifier(symbol),tmp_symbol[rule_nr])
if pad:
h_symbols = [
Node(frozenset(tmp_symbol[rule_nr]), rule_nr, reason=tmp_reason[rule_nr]) if rule_nr in tmp_symbol else Node(frozenset(), rule_nr) for
rule_nr in relevant_indices]
else:
h_symbols = [Node(frozenset(tmp_symbol[rule_nr]), rule_nr, reason=frozenset(tmp_reason[rule_nr])) for rule_nr in tmp_symbol.keys()]
h_symbols = [
Node(frozenset(tmp_symbol[rule_nr]), rule_nr, reason=tmp_reason[rule_nr]) if rule_nr in tmp_symbol else Node(frozenset(), rule_nr)
for rule_nr in range(1, max(tmp_symbol.keys(), default=-1) + 1)]

return h_symbols

Expand Down Expand Up @@ -167,3 +172,54 @@ def save_model(model: Model) -> Collection[str]:
for part in model.symbols(atoms=True):
wrapped.append(f"{part}.")
return wrapped


def get_recursion_subgraph(facts: frozenset, supernode_symbols: frozenset,
transformation: Union[AST, str], conflict_free_h: str,
get_conflict_free_model: callable = lambda s: "model",
get_conflict_free_iterindex: callable = lambda s: "n") -> Union[bool, nx.DiGraph]:
"""
Get a recursion explanation for the given facts and the recursive transformation.
Generate graph from explanation, sorted by the iteration step number.
:param facts: The symbols that were true before the recursive node.
:param supernode_symbols: The SymbolIdentifiers of the recursive node.
:param transformation: The recursive transformation. An ast object.
:param conflict_free_h: The name of the h predicate.
"""
init = [fact.symbol for fact in facts]
justification_program = ""
model_str: str = get_conflict_free_model()
n_str: str = get_conflict_free_iterindex()
for rule in transformation.rules:
# TODO: get reasons by transformer
tupleified = ",".join(list(map(str, rule.body)))
justification_head = f"{conflict_free_h}({n_str}, {rule.head}, ({tupleified}))"
justification_body = ",".join(
f"{model_str}({atom})" for atom in rule.body)
justification_body += f", not {model_str}({rule.head})"

justification_program += f"{justification_head} :- {justification_body}.\n"

justification_program += f"{model_str}(@new())."

h_syms = set()
try:
RecursionReasoner(init=init,
program=justification_program,
callback=h_syms.add,
conflict_free_h=conflict_free_h,
conflict_free_n=n_str).main()
except RuntimeError:
return False

h_syms = collect_h_symbols_and_create_nodes(h_syms, relevant_indices = [], pad = False, supernode_symbols = supernode_symbols)
# here: rule_nr is iteration number
h_syms.sort(key=lambda node: node.rule_nr)
h_syms.insert(0, Node(frozenset(facts), -1))
insert_atoms_into_nodes(h_syms)

reasoning_subgraph = nx.DiGraph()
for a, b in pairwise(h_syms[1:]):
reasoning_subgraph.add_edge(a, b)
return reasoning_subgraph if reasoning_subgraph.size() != 0 else False
97 changes: 1 addition & 96 deletions backend/src/viasp/asp/recursion.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,4 @@
from typing import List, Collection, Dict, Union
from collections import defaultdict

import networkx as nx
from clingo.script import enable_python
from clingo import Symbol, Number, Control
from clingo.ast import AST

from .utils import insert_atoms_into_nodes
from ..shared.model import Node, SymbolIdentifier
from ..shared.util import pairwise
from clingo import Number, Control


class RecursionReasoner:
Expand Down Expand Up @@ -40,88 +30,3 @@ def main(self):

for x in control.symbolic_atoms.by_signature(self.conflict_free_h, 3):
self.register_h_symbols(x.symbol)


def get_recursion_subgraph(facts: frozenset, supernode_symbols: frozenset,
transformation: Union[AST, str], conflict_free_h: str,
get_conflict_free_model: callable = lambda s: "model",
get_conflict_free_iterindex: callable = lambda s: "n") -> Union[bool, nx.DiGraph]:
"""
Get a recursion explanation for the given facts and the recursive transformation.
Generate graph from explanation, sorted by the iteration step number.
:param facts: The symbols that were true before the recursive node.
:param supernode_symbols: The SymbolIdentifiers of the recursive node.
:param transformation: The recursive transformation. An ast object.
:param conflict_free_h: The name of the h predicate.
"""
enable_python()
init = [fact.symbol for fact in facts]
justification_program = ""
model_str:str = get_conflict_free_model()
n_str:str = get_conflict_free_iterindex()
for i,rule in enumerate(transformation.rules):
# TODO: get reasons by transformer
tupleified = ",".join(list(map(str,rule.body)))
justification_head = f"{conflict_free_h}({n_str}, {rule.head}, ({tupleified}))"
justification_body = ",".join(f"{model_str}({atom})" for atom in rule.body)
justification_body += f", not {model_str}({rule.head})"

justification_program += f"{justification_head} :- {justification_body}.\n"

justification_program += f"{model_str}(@new())."

h_syms = set()
try:
RecursionReasoner(init = init,
program = justification_program,
callback = h_syms.add,
conflict_free_h = conflict_free_h,
conflict_free_n = n_str).main()
except RuntimeError:
return False

h_syms = collect_h_symbols_and_create_nodes(h_syms, supernode_symbols)
h_syms.sort(key=lambda node: node.rule_nr) # here: rule_nr is iteration number
h_syms.insert(0, Node(frozenset(facts), -1))
insert_atoms_into_nodes(h_syms)

reasoning_subgraph = nx.DiGraph()
for a, b in pairwise(h_syms[1:]):
reasoning_subgraph.add_edge(a, b)
return reasoning_subgraph if reasoning_subgraph.size() != 0 else False


def collect_h_symbols_and_create_nodes(h_symbols: Collection[Symbol], supernode_symbols: frozenset) -> List[Node]:
"""
Collect all h symbols and create nodes for each iteration step.
Adapted from the function for the same purpose on the main graph.
iter_nr is the reference to which iteration of the recursive node the symbol belongs to.
It is used similarly to the rule_nr in the main graph.
The SymbolIdentifiers are copied from the supernode_symbols to keep the UUIDs consistent.
:param h_symbols: The h symbols of the recursive node.
:param supernode_symbols: The supernode_symbols are the symbols of the recursive node.
They are used to keep the SymbolIdentifiers' UUIDs consistent.
"""
tmp_symbol: Dict[int, List[SymbolIdentifier]] = defaultdict(list)
tmp_reason: Dict[int, Dict[Symbol, List[Symbol]]] = defaultdict(dict)

for sym in h_symbols:
iter_nr, symbol, reasons = sym.arguments
tmp_symbol[iter_nr.number].append(symbol)
tmp_reason[iter_nr.number][str(symbol)] = reasons.arguments
for iter_nr in tmp_symbol.keys():
tmp_symbol[iter_nr] = set(tmp_symbol[iter_nr])
tmp_symbol[iter_nr] = map(lambda symbol: next(filter(
lambda supernode_symbol: supernode_symbol==symbol, supernode_symbols)) if
symbol in supernode_symbols else SymbolIdentifier(symbol),
tmp_symbol[iter_nr])

h_symbols = [
Node(frozenset(tmp_symbol[iter_nr]), iter_nr, reason=tmp_reason[iter_nr])
if iter_nr in tmp_symbol else Node(frozenset(), iter_nr)
for iter_nr in range(1, max(tmp_symbol.keys(), default=-1) + 1)]

return h_symbols

0 comments on commit 13b846c

Please sign in to comment.