Skip to content

Commit

Permalink
Pull changes from clingo-dl
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanzwicknagl committed Apr 1, 2023
1 parent 5ce5a0f commit 8225eba
Show file tree
Hide file tree
Showing 35 changed files with 1,836 additions and 118 deletions.
2 changes: 1 addition & 1 deletion backend/src/viasp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
load_program_file, load_program_string,
mark_from_clingo_model, mark_from_file, mark_from_string,
relax_constraints, show, unmark_from_clingo_model,
unmark_from_file, unmark_from_string)
unmark_from_file, unmark_from_string, register_transformer)
from .wrapper import Control, Control2
15 changes: 13 additions & 2 deletions backend/src/viasp/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from clingo import Control as InnerControl
from clingo import Model as clingo_Model
from clingo import ast
from clingo.ast import AST, ASTSequence, ASTType, Symbol
from clingo.ast import AST, ASTSequence, ASTType, Symbol, Transformer

from .shared.defaults import STDIN_TMP_STORAGE_PATH
from .shared.io import clingo_symbols_to_stable_model
Expand All @@ -36,7 +36,8 @@
"clear",
"show",
"relax_constraints",
"clingraph"
"clingraph",
"register_transformer",
]

SHOWCONNECTOR = None
Expand Down Expand Up @@ -278,6 +279,16 @@ def clingraph(viz_encoding, engine) -> None:
connector = _get_connector()
connector.clingraph(viz_encoding, engine)

def register_transformer(transformer: Transformer, imports: str = "", path: str = "") -> None:
r"""
Register a transformer to the backend. The program will be transformed
in the backend before further processing is made.
:param transformer: ``Transformer``
The transformer to register.
"""
connector = _get_connector()
connector.register_transformer(transformer, imports, path)

# ------------------------------------------------------------------------------
# Parse ASP facts from a string or files into a clingo model
Expand Down
13 changes: 5 additions & 8 deletions backend/src/viasp/asp/justify.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ def get_h_symbols_from_model(wrapped_stable_model: Iterable[Symbol],
rules_that_are_reasons_why = []
ctl = Control()
stringified = "".join(map(str, transformed_prg))
get_new_atoms_rule = f"{h}(I, H, G) :- {h}(I, H, G), not {h}(II,H,_) : II<I, {h}(II,_,_)."
new_head = f"_{h}"
get_new_atoms_rule = f"{new_head}(I, H, G) :- {h}(I, H, G), not {h}(II,H,_) : II<I, {h}(II,_,_)."
ctl.add("base", [], "".join(map(str, constants)))
ctl.add("base", [], "".join(map(stringify_fact, facts)))
ctl.add("base", [], stringified)
ctl.add("base", [], "".join(map(str, wrapped_stable_model)))
ctl.add("base", [], get_new_atoms_rule)
ctl.ground([("base", [])])
for x in ctl.symbolic_atoms.by_signature(h, 3):
for x in ctl.symbolic_atoms.by_signature(new_head, 3):
if x.symbol.arguments[1] in facts:
continue
rules_that_are_reasons_why.append(x.symbol)
Expand Down Expand Up @@ -106,12 +107,8 @@ def make_reason_path_from_facts_to_stable_model(wrapped_stable_model,
def join_paths_with_facts(paths: Collection[nx.DiGraph]) -> nx.DiGraph:
combined = nx.DiGraph()
for path in paths:
for node in path.nodes():
if node not in combined.nodes:
combined.add_node(node)
for u, v, r in path.edges(data=True):
if u in combined.nodes and v in combined.nodes:
combined.add_edge(u, v, transformation=r["transformation"])
combined.add_nodes_from(path.nodes(data=True))
combined.add_edges_from(path.edges(data=True))
return combined


Expand Down
179 changes: 149 additions & 30 deletions backend/src/viasp/asp/reify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Dict, List, Tuple, Iterable, Set, Collection, Any, Union
from typing import Dict, List, Tuple, Iterable, Set, Collection, Any, Union, Sequence

import clingo
import networkx as nx
Expand All @@ -21,6 +21,8 @@ def make_signature(literal: clingo.ast.Literal) -> Tuple[str, int]:
if literal.atom.ast_type in [ast.ASTType.BodyAggregate]:
return literal, 0
unpacked = literal.atom.symbol
if unpacked.ast_type == ast.ASTType.Pool:
unpacked = unpacked.arguments[0]
return unpacked.name, len(unpacked.arguments) if hasattr(unpacked, "arguments") else 0


Expand Down Expand Up @@ -69,28 +71,97 @@ def visit(self, ast: AST, *args: Any, **kwargs: Any) -> Union[AST, None]:

class DependencyCollector(Transformer):

def visit_Aggregate(self, aggregate, deps={}, names=set()):
conditional_literals = aggregate.elements
for elem in conditional_literals:
self.visit(elem, deps=deps)
return aggregate
def visit_Aggregate(self, aggregate: AST, **kwargs: Any) -> AST:
kwargs.update({"in_aggregate": True})
return aggregate.update(**self.visit_children(aggregate, **kwargs))

def visit_ConditionalLiteral(self, conditional_literal, deps={}, names=set()):
self.visit(conditional_literal.literal)
deps[conditional_literal.literal] = []
for condition in conditional_literal.condition:
deps[conditional_literal.literal].append(condition)
return conditional_literal
def visit_BodyAggregateElement(self, aggregate: AST, **kwargs: Any) -> AST:
# update flag
kwargs.update({"in_aggregate": True})

# collect conditions
conditions = kwargs.get("body_aggregate_elements", [])
conditions.extend(aggregate.condition)
return aggregate.update(**self.visit_children(aggregate, **kwargs))

class NameCollector(Transformer):
def visit_ConditionalLiteral(self, conditional_literal: AST, **kwargs: Any) -> AST:
deps = kwargs.get("deps", {})
new_body = kwargs.get("new_body", [])

def visit_Variable(self, variable, deps={}, names=set()):
deps[conditional_literal.literal] = []
for condition in conditional_literal.condition:
deps[conditional_literal.literal].append(condition)
new_body.extend(conditional_literal.condition)
return conditional_literal.update(**self.visit_children(conditional_literal, **kwargs))

def visit_Literal(self, literal: AST, **kwargs: Any) -> AST:
reasons = kwargs.get("reasons", [])
new_body = kwargs.get("new_body", [])

atom = literal.atom
if (literal.sign == ast.Sign.NoSign and
atom.ast_type == ast.ASTType.SymbolicAtom):
reasons.append(atom)
new_body.append(literal)
return literal.update(**self.visit_children(literal, **kwargs))

def visit_Variable(self, variable: AST, **kwargs: Any) -> AST:
# collect names
names = kwargs.get("names", set())
names.add(variable.name)

# rename if necessary
rename_variables = kwargs.get("rename_variables", False)
in_aggregate = kwargs.get("in_aggregate", False)
if rename_variables and in_aggregate:
return ast.Variable(variable.location, f"_{variable.name}")
return variable.update(**self.visit_children(variable, **kwargs))


class TheoryTransformer(Transformer):

def visit_TheoryAtom(self, atom: AST) -> AST:
term = atom.term
new_heads = []
if term.name == "diff":
content: List[Union[AST, List[AST]]] = []
self.visit_children(atom, theory_content=content)
for i in content:
inner_var = ast.Function(term.location, '', i, 0) if isinstance(i, List) else i
inner__ = ast.Variable(term.location, 'X') # TODO: get conflict free version?
outer_function = ast.Function(term.location, 'dl', [inner_var, inner__], 0)
outer_symatom = ast.SymbolicAtom(outer_function)
new_heads.append(ast.Literal(term.location, 0, outer_symatom))

return new_heads

def visit_TheoryAtomElement(self, atom_element, in_elem: bool = False, theory_content: List = []):
return atom_element.update(**self.visit_children(atom_element, True, theory_content))

def visit_TheoryGuard(self, guard, in_elem: bool = False, theory_content: List = []):
return guard.update(**self.visit_children(guard, in_elem, theory_content))

def visit_SymbolicTerm(self, term, in_elem: bool = False, theory_content: List = []):
if in_elem:
if term.symbol.type == clingo.SymbolType.Function:
theory_content.append(term)
return term

def visit_Variable(self, variable, in_elem: bool = False, theory_content: List = []):
if in_elem:
theory_content.append(variable)
return variable

def visit_TheorySequence(self, sequence, in_elem: bool = False, theory_content: List = []):
if in_elem:
variables: List[AST] = []
for s in sequence.terms:
variables.append(s) # TODO: actually s could be any other theory_term
theory_content.append(variables)
return sequence


class ProgramAnalyzer(DependencyCollector, FilteredTransformer, NameCollector):
class ProgramAnalyzer(DependencyCollector, FilteredTransformer):
"""
Receives a ASP program and finds it's dependencies within, can sort a program by it's dependencies.
"""
Expand All @@ -113,7 +184,7 @@ def __init__(self):
def _get_conflict_free_version_of_name(self, name: str) -> Collection[str]:
candidates = [name for name, _ in self.dependants.keys()]
candidates.extend([name for name, _ in self.conditions.keys()])
candidates.extend([fact.atom.symbol.name for fact in self.facts])
candidates.extend([getattr(getattr(getattr(fact,"atom"), "symbol"), "name", "") for fact in self.facts])
candidates.extend([name for name in self.names])
candidates = set(candidates)
current_best = name
Expand Down Expand Up @@ -166,9 +237,8 @@ def register_rule_dependencies(self, rule: Rule, deps: Dict[Literal, List[Litera
if (l.atom.symbol.name == u.atom.symbol.name and \
l.sign == ast.Sign.NoSign):
self.positive_conditions[u_sig].add(rule)


for v in filter(lambda symbol: symbol.atom.ast_type != ASTType.BooleanConstant, deps.keys()):

for v in filter(lambda symbol: symbol.atom.ast_type != ASTType.BooleanConstant if hasattr(symbol, "atom") else False, deps.keys()):
v_sig = make_signature(v)
self.dependants[v_sig].add(rule)

Expand All @@ -185,10 +255,17 @@ def visit_Rule(self, rule: Rule):
deps[rule.head] = []
for _, cond in deps.items():
cond.extend(filter(filter_body_arithmetic, rule.body))
cond.extend(self.get_body_aggregate_elements(rule.body))
self.register_symbolic_dependencies(deps)
self.names = self.names.union(names)
self.register_rule_dependencies(rule, deps)
self.rules.append(rule)

def get_body_aggregate_elements(self, body: Sequence[AST]) -> List[AST]:
body_aggregate_elements: List[AST] = []
for elem in body:
self.visit(elem, body_aggregate_elements=body_aggregate_elements)
return body_aggregate_elements


def visit_Minimize(self, minimize: Minimize):
Expand All @@ -201,8 +278,22 @@ def visit_Definition(self, definition):
self.constants.add(definition)
return definition

def add_program(self, program: str) -> None:
parse_string(program, lambda statement: self.visit(statement))
def add_program(self, program: str, registered_transformer: Transformer = None) -> None:
if registered_transformer is not None:
registered_visitor = registered_transformer()
new_program: List[AST] = []

def add(statement):
nonlocal new_program
if isinstance(statement, List):
new_program.extend(statement)
else:
new_program.append(statement)
parse_string(program, lambda statement: add(registered_visitor.visit(statement)))
for statement in new_program:
self.visit(statement)
else:
parse_string(program, lambda statement: self.visit(statement))

def sort_program(self, program) -> List[Transformation]:
parse_string(program, lambda rule: self.visit(rule))
Expand Down Expand Up @@ -253,7 +344,20 @@ def check_positive_recursion(self):
deps1 = merge_constraints(deps1)
deps2, where1 = merge_cycles(deps1)
_, where2 = remove_loops(deps2)
return where1.union(where2)
return {recursive_set for recursive_set in where1.union(where2)
if self.should_include_recursive_set(recursive_set)}

def should_include_recursive_set(self, recursive_set):
"""
Drop the set of integrity constraints from the recursive set.
"""
for rule in recursive_set:
head = getattr(rule, "head", None)
atom = getattr(head, "atom", None)
ast_type = getattr(atom, "ast_type", None)
if ast_type != ASTType.BooleanConstant:
return True
return False



Expand All @@ -267,7 +371,10 @@ def __init__(self, rule_nr=1, h="h", model="model", \
self.get_conflict_free_variable = get_conflict_free_variable

def _nest_rule_head_in_h_with_explanation_tuple(self, loc: ast.Location,
dependant: ast.Literal, body: List[ast.Literal]):
dependant: ast.Literal,
conditions: List[ast.Literal],
body: List[ast.Literal],
new_body: List[ast.Literal]):
"""
In: H :- B.
Out: h(0, H, pos_atoms(B)),
Expand All @@ -277,10 +384,14 @@ def _nest_rule_head_in_h_with_explanation_tuple(self, loc: ast.Location,
loc_atm = ast.SymbolicAtom(loc_fun)
loc_lit = ast.Literal(loc, ast.Sign.NoSign, loc_atm)
reasons = []
for literal in body:
if literal.sign == ast.Sign.NoSign and \
literal.atom.ast_type == ast.ASTType.SymbolicAtom:
for literal in conditions:
if literal.atom.ast_type == ast.ASTType.SymbolicAtom:
reasons.append(literal.atom)
for literal in body:
reason_literals = []
_ = self.visit(literal, reasons = reason_literals, new_body = new_body)
reasons.extend([r for r in reason_literals if r not in reasons])

reason_fun = ast.Function(loc, '', reasons, 0)
reason_lit = ast.Literal(loc, ast.Sign.NoSign, reason_fun)

Expand Down Expand Up @@ -319,13 +430,21 @@ def visit_Rule(self, rule: clingo.ast.Rule):
variables,
False))
dependant = ast.Literal(loc, ast.Sign.NoSign, symbol)

new_body = []
new_head_s = self._nest_rule_head_in_h_with_explanation_tuple(rule.location,
dependant,
rule.body)
# Add reified head to body
new_body = [dependant]
new_body.extend(rule.body)
conditions,
rule.body,
new_body)

new_body.insert(0, dependant)
for r in rule.body:
new_body.append(self.visit(r, rename_variables=True))
new_body.extend(conditions)
# Remove duplicates but preserve order
new_body = [x for i, x in enumerate(new_body) if x not in new_body[:i]]

new_rules.extend([Rule(rule.location, new_head, new_body) for new_head in new_head_s])

return new_rules
Expand Down
16 changes: 14 additions & 2 deletions backend/src/viasp/clingoApiClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import requests
from .shared.defaults import DEFAULT_BACKEND_URL
from .shared.io import DataclassJSONEncoder
from .shared.model import ClingoMethodCall, StableModel
from .shared.model import ClingoMethodCall, StableModel, TransformerTransport
from .shared.interfaces import ViaspClient
from .shared.simple_logging import log, Level, error

Expand Down Expand Up @@ -96,4 +96,16 @@ def clingraph(self, viz_encoding_path, engine):
if r.ok:
log(f"Cligraph visualization in progress.")
else:
error(f"Cligraph visualization failed [{r.status_code}] ({r.reason})")
error(f"Cligraph visualization failed [{r.status_code}] ({r.reason})")

def _register_transformer(self, transformer, imports, path):
serializable_transformer = TransformerTransport.merge(transformer, imports, path)
serialized = json.dumps(serializable_transformer,
cls=DataclassJSONEncoder)
r = requests.post(f"{self.backend_url}/control/add_transformer",
data=serialized,
headers={'Content-Type': 'application/json'})
if r.ok:
log(f"Transformer registered.")
else:
error(f"Registering transformer failed [{r.status_code}] ({r.reason})")

0 comments on commit 8225eba

Please sign in to comment.