Skip to content

Commit

Permalink
Improve recursion justifier program
Browse files Browse the repository at this point in the history
This commit improves the recursion justifier program for showing
correct recursive subgraphs for recursion with choice.

For example, the rule for a loop recursion with choice is as follows:
```
% input rule:
{b(X)} :- b(X-1), a(X).

% justifier rule:
h(n,b(X),(b((X-1)),a(X))) :-
	model(a(X));
	model(b((X-1)));
	@Derivable(b(X)) = 1,
	not model(b(X)).
```

The addition is the @Derivable(b(X)) = 1 condition, which is used to
check if the super-node contains the b(X) symbol. If it is in the super-
node, it can be derived. If it is not in the super-node, it cannot be
derived.

Resolves: #12
  • Loading branch information
stephanzwicknagl committed Apr 25, 2024
1 parent 19f2382 commit 9cab207
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
1 change: 1 addition & 0 deletions backend/src/viasp/asp/justify.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def get_recursion_subgraph(

try:
RecursionReasoner(init=init,
derivables=supernode_symbols,
program=justification_program,
callback=h_syms.add,
conflict_free_h=conflict_free_h,
Expand Down
4 changes: 4 additions & 0 deletions backend/src/viasp/asp/recursion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@ class RecursionReasoner:
def __init__(self, **kwargs):
self.atoms = []
self.init = kwargs.pop("init", [])
self.derivables = kwargs.pop("derivables", [])
self.program = kwargs.pop("program", "")
self.register_h_symbols = kwargs.pop("callback", None)
self.conflict_free_h = kwargs.pop("conflict_free_h", "h")
self.conflict_free_n = kwargs.pop("conflict_free_n", "n")

def new(self):
return self.atoms

def derivable(self, atom):
return Number(1) if atom in self.derivables else Number(0)

def main(self):
control = Control()
Expand Down
10 changes: 5 additions & 5 deletions backend/src/viasp/asp/reify.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import clingo
import networkx as nx
from clingo import ast, Symbol
from clingo import ast, Symbol, Number
from clingo.ast import (
Transformer,
parse_string,
Expand Down Expand Up @@ -826,14 +826,14 @@ def visit_Rule(self, rule: ast.Rule) -> List[AST]: # type: ignore
conditions = [wrapper.visit(c) for c in conditions]

# Append dependant (wrapped, negated)
dep_fun = ast.Function(loc, f"{self.model_str}", [dependant], 0)
dep_fun = ast.Function(loc, self.model_str, [dependant], 0)
dep_atm = ast.SymbolicAtom(dep_fun)
conditions.append(ast.Literal(loc, ast.Sign.Negation, dep_atm))

# # Append dependant wrapped in derivable
# derivable_fun = ast.Function(loc, self.derivable_str, [dependant], 0)
# derivable_atm = ast.SymbolicAtom(derivable_fun)
# conditions.append(ast.Literal(loc, ast.Sign.NoSign, derivable_atm))
derivable_fun = ast.Function(loc, self.derivable_str, [dependant], 1)
derivable_comp = ast.Comparison(derivable_fun, [ast.Guard(5, ast.SymbolicTerm(loc, Number(1)))])
conditions.append(ast.Literal(loc, ast.Sign.NoSign, derivable_comp))

new_rules.extend([
ast.Rule(rule.location, new_head, conditions)
Expand Down

0 comments on commit 9cab207

Please sign in to comment.