Skip to content

Commit

Permalink
simplify reason collection
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanzwicknagl committed Apr 7, 2023
1 parent 897d8db commit b27790a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 30 deletions.
57 changes: 32 additions & 25 deletions backend/src/viasp/asp/reify.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ class DependencyCollector(Transformer):

def visit_Aggregate(self, aggregate: AST, **kwargs: Any) -> AST:
kwargs.update({"in_aggregate": True})
return aggregate.update(**self.visit_children(aggregate, **kwargs))
new_body = kwargs.get("new_body", [])

aggregate_update = aggregate.update(**self.visit_children(aggregate, **kwargs))
new_body.append(aggregate_update)
return aggregate_update

def visit_BodyAggregateElement(self, aggregate: AST, **kwargs: Any) -> AST:
# update flag
Expand All @@ -95,27 +99,34 @@ def visit_ConditionalLiteral(self, conditional_literal: AST, **kwargs: Any) -> A
return conditional_literal.update(**self.visit_children(conditional_literal, **kwargs))

def visit_Literal(self, literal: AST, **kwargs: Any) -> AST:
reasons = kwargs.get("reasons", [])
new_body = kwargs.get("new_body", [])
reasons: List[AST] = kwargs.get("reasons", [])
new_body: List[AST] = kwargs.get("new_body", [])

atom = literal.atom
literal_update = literal.update(**self.visit_children(literal, **kwargs))

atom: AST = literal.atom
if (literal.sign == ast.Sign.NoSign and
atom.ast_type == ast.ASTType.SymbolicAtom):
reasons.append(atom)
new_body.append(literal)
new_body.append(literal_update)
return literal.update(**self.visit_children(literal, **kwargs))

def visit_Variable(self, variable: AST, **kwargs: Any) -> AST:
# collect names
names = kwargs.get("names", set())
names: Set = kwargs.get("names", set())
names.add(variable.name)

# rename if necessary
rename_variables = kwargs.get("rename_variables", False)
in_aggregate = kwargs.get("in_aggregate", False)
rename_variables: bool = kwargs.get("rename_variables", False)
in_aggregate: bool = kwargs.get("in_aggregate", False)
if rename_variables and in_aggregate:
return ast.Variable(variable.location, f"_{variable.name}")
return variable.update(**self.visit_children(variable, **kwargs))

def visit_BooleanConstant(self, boolean_constant: AST, **kwargs: Any) -> AST:
new_body: List[AST] = kwargs.get("new_body", [])
new_body.append(boolean_constant)
return boolean_constant.update(**self.visit_children(boolean_constant, **kwargs))


class TheoryTransformer(Transformer):
Expand Down Expand Up @@ -373,8 +384,7 @@ def __init__(self, rule_nr=1, h="h", model="model", \
def _nest_rule_head_in_h_with_explanation_tuple(self, loc: ast.Location,
dependant: ast.Literal,
conditions: List[ast.Literal],
body: List[ast.Literal],
new_body: List[ast.Literal]):
reasons: List[ast.Literal]):
"""
In: H :- B.
Out: h(0, H, pos_atoms(B)),
Expand All @@ -383,15 +393,11 @@ def _nest_rule_head_in_h_with_explanation_tuple(self, loc: ast.Location,
loc_fun = ast.Function(loc, str(self.rule_nr), [], False)
loc_atm = ast.SymbolicAtom(loc_fun)
loc_lit = ast.Literal(loc, ast.Sign.NoSign, loc_atm)
reasons = []
for literal in conditions:
if literal.atom.ast_type == ast.ASTType.SymbolicAtom:
reasons.append(literal.atom)
for literal in body:
reason_literals = []
_ = self.visit(literal, reasons = reason_literals, new_body = new_body)
reasons.extend([r for r in reason_literals if r not in reasons])

reasons.reverse()
reasons = [r for i,r in enumerate(reasons) if r not in reasons[:i]]
reason_fun = ast.Function(loc, '', reasons, 0)
reason_lit = ast.Literal(loc, ast.Sign.NoSign, reason_fun)

Expand Down Expand Up @@ -431,20 +437,21 @@ def visit_Rule(self, rule: clingo.ast.Rule):
False))
dependant = ast.Literal(loc, ast.Sign.NoSign, symbol)

new_body = []
new_head_s = self._nest_rule_head_in_h_with_explanation_tuple(rule.location,
dependant,
conditions,
rule.body,
new_body)
new_body: List[ast.Literal] = []
reason_literals: List[ast.Literal] = []
_ = self.visit_sequence(rule.body, reasons = reason_literals, new_body = new_body, rename_variables = False)
new_head_s = self._nest_rule_head_in_h_with_explanation_tuple(
rule.location,
dependant,
conditions,
reason_literals)

new_body.insert(0, dependant)
for r in rule.body:
new_body.append(self.visit(r, rename_variables=True))
new_body.extend(conditions)
# Remove duplicates but preserve order
new_body = [x for i, x in enumerate(new_body) if x not in new_body[:i]]

# rename variables inside body aggregates
new_body = self.visit_sequence(new_body, rename_variables=True)
new_rules.extend([Rule(rule.location, new_head, new_body) for new_head in new_head_s])

return new_rules
Expand Down
10 changes: 5 additions & 5 deletions backend/test/test_reification.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_normal_rule_without_negation_is_transformed_correctly():

def test_multiple_nested_variable_gets_transformed_correctly():
program = "x(1). y(1). l(x(X),y(Y)) :- x(X), y(Y)."
expected = "x(1). y(1). h(1, l(x(X),y(Y)), (x(X),y(Y))) :- l(x(X),y(Y)), x(X), y(Y)."
expected = "x(1). y(1). h(1, l(x(X),y(Y)), (y(Y),x(X))) :- l(x(X),y(Y)), x(X), y(Y)."
assertProgramEqual(transform(program), parse_program_to_ast(expected))


Expand Down Expand Up @@ -119,30 +119,30 @@ def test_head_aggregate_groups_is_transformed_correctly():
rule = "{a(X) : b(X), c(X); d(X) : e(X), X=1..3 }:- f(X)."
expected = """#program base.
h(1, d(X), (e(X),f(X))) :- d(X), f(X), e(X), X=1..3.
h(1, a(X), (b(X), c(X), f(X))) :- a(X), f(X), b(X), c(X)."""
h(1, a(X), (c(X),b(X),f(X))) :- a(X), f(X), b(X), c(X)."""
assertProgramEqual(transform(rule), parse_program_to_ast(expected))


def test_aggregate_choice_is_transformed_correctly():
rule = "1{a(X) : b(X), c(X); d(X) : e(X), X=1..3 }1:- f(X)."
expected = """#program base.
h(1, d(X), (e(X),f(X))) :- d(X), f(X), e(X), X=1..3.
h(1, a(X), (b(X), c(X), f(X))) :- a(X), f(X), b(X), c(X)."""
h(1, a(X), (c(X),b(X),f(X))) :- a(X), f(X), b(X), c(X)."""
assertProgramEqual(transform(rule), parse_program_to_ast(expected))


def test_multiple_conditional_groups_in_head():
rule = "1 #sum { X,Y : a(X,Y) : b(Y), c(X) ; X,Z : b(X,Z) : e(Z) } :- c(X)."
expected = """#program base.
h(1, a(X,Y), (b(Y), c(X))) :- a(X,Y), c(X), b(Y).
h(1, a(X,Y), (c(X),b(Y))) :- a(X,Y), c(X), b(Y).
h(1, b(X,Z), (e(Z), c(X))) :- b(X,Z), c(X), e(Z).
"""
assertProgramEqual(transform(rule), parse_program_to_ast(expected))


def test_multiple_aggregates_in_body():
rule = "s(Y) :- r(Y), 2 #sum{X : p(X,Y), q(X) } 7."
expected = "#program base. h(1, s(Y), (r(Y),p(X,Y), q(X))) :- s(Y), r(Y), p(X,Y), q(X), 2 #sum{_X : p(_X,_Y), q(_X) } 7."
expected = "#program base. h(1, s(Y), (q(X),p(X,Y),r(Y))) :- s(Y), r(Y), p(X,Y), q(X), 2 #sum{_X : p(_X,_Y), q(_X) } 7."
assertProgramEqual(transform(rule), parse_program_to_ast(expected))


Expand Down

0 comments on commit b27790a

Please sign in to comment.