Skip to content

Commit

Permalink
Use RuleContainer to manage rules
Browse files Browse the repository at this point in the history
The rule container can be stored in the database, so that the input
program does never need to be parsed again. The Container has both the
AST representation as well as the string representation of the rules.

When the dependency graph is generated, the RuleContainer is created
from the input AST. A rule container can then be merged with another
to remove cycles/loops in the graph. A component of the viasp graph
then contains a RuleContainer.

This commit contains changes to the whole codebase to reflect the new
Container. Tests are updated.

Contributes: #34
  • Loading branch information
stephanzwicknagl committed Apr 22, 2024
1 parent 722d5a9 commit 0e92d93
Show file tree
Hide file tree
Showing 32 changed files with 413 additions and 515 deletions.
4 changes: 2 additions & 2 deletions backend/src/viasp/asp/justify.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .reify import ProgramAnalyzer, reify_recursion_transformation
from .recursion import RecursionReasoner
from .utils import insert_atoms_into_nodes, identify_reasons, calculate_spacing_factor
from ..shared.model import Node, Transformation, SymbolIdentifier
from ..shared.model import Node, RuleContainer, Transformation, SymbolIdentifier
from ..shared.simple_logging import info
from ..shared.util import pairwise, get_leafs_from_graph

Expand Down Expand Up @@ -148,7 +148,7 @@ def append_noops(result_graph: nx.DiGraph,
result_graph.add_edge(leaf,
noop_node,
transformation=Transformation(
next_transformation_id, tuple(pass_through)))
next_transformation_id, RuleContainer(ast=tuple(pass_through))))


def build_graph(wrapped_stable_models: List[List[str]],
Expand Down
79 changes: 31 additions & 48 deletions backend/src/viasp/asp/reify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Dict, List, Tuple, Iterable, Set, Collection, Any, Union, Sequence, Generator, Optional, cast
from typing import Dict, List, Tuple, Iterable, Set, Collection, Any, Union, Sequence, Optional, cast

import clingo
import networkx as nx
Expand Down Expand Up @@ -169,7 +169,9 @@ class ProgramAnalyzer(DependencyCollector, FilteredTransformer):
Receives a ASP program and finds it's dependencies within, can sort a program by it's dependencies.
"""

def __init__(self, dependency_graph: Optional[nx.DiGraph] = None):
def __init__(self, dependants: Optional[Dict[Tuple[str, int], Set[ast.Rule]]] = None, # type: ignore
conditions: Optional[Dict[Tuple[str, int], Set[ast.Rule]]] = None, # type: ignore
dependency_graph: Optional[nx.DiGraph] = None):
DependencyCollector.__init__(self, in_analyzer=True)
FilteredTransformer.__init__(self)
self.dependants: Dict[Tuple[str, int],
Expand Down Expand Up @@ -463,6 +465,8 @@ def add(statement):
lambda statement: self.visit(statement) and None)

def sort_program(self, program) -> List[Transformation]:
from viasp.server.database import GraphAccessor, get_or_create_encoding_id
GraphAccessor().save_program(program, get_or_create_encoding_id())
parse_string(program, lambda rule: self.visit(rule) and None)
sorted_program = self.primary_sort_program_by_dependencies()
return [
Expand All @@ -471,24 +475,20 @@ def sort_program(self, program) -> List[Transformation]:
]

def get_sort_program_and_graph(
self, program: str) -> Tuple[List[Tuple[AST, ...]], nx.DiGraph]:
self, program: str) -> Tuple[List[RuleContainer], nx.DiGraph]:
from viasp.server.database import GraphAccessor, get_or_create_encoding_id
GraphAccessor().save_program(program, get_or_create_encoding_id())
parse_string(program, lambda rule: self.visit(rule) and None)
sorted_programs = self.primary_sort_program_by_dependencies()
return sorted_programs, self.make_dependency_graph(
self.dependants, self.conditions)

def get_sorted_program(
self) -> Generator[List[Transformation], None, None]:
sorted_programs = self.primary_sort_program_by_dependencies()
for program in sorted_programs:
yield [Transformation(i, (prg)) for i, prg in enumerate(program)]

def get_primary_sort(self) -> List[Transformation]:
def get_sorted_program(self) -> List[Transformation]:
sorted_program = self.primary_sort_program_by_dependencies()
return self.make_transformations_from_sorted_program(sorted_program)

def make_transformations_from_sorted_program(
self, sorted_program: List[Tuple[ast.Rule]] # type: ignore
self, sorted_program: List[RuleContainer] # type: ignore
) -> List[Transformation]:
adjacency_index_mapping = self.get_index_mapping_for_adjacent_topological_sorts(
sorted_program)
Expand Down Expand Up @@ -516,32 +516,32 @@ def make_dependency_graph(

for deps in head_dependencies.values():
for dep in deps:
g.add_node(tuple([dep]))
g.add_node(RuleContainer(tuple([dep])))
for deps in body_dependencies.values():
for dep in deps:
g.add_node(tuple([dep]))
g.add_node(RuleContainer(tuple([dep])))

for head_signature, rules_with_head in head_dependencies.items():
dependent_rules = body_dependencies.get(head_signature, [])
for parent_rule in rules_with_head:
for dependent_rule in dependent_rules:
g.add_edge(tuple([parent_rule]), tuple([dependent_rule]))
g.add_edge(RuleContainer(tuple([parent_rule])), RuleContainer(tuple([dependent_rule])))

return g

def primary_sort_program_by_dependencies(
self) -> List[ast.Rule]: # type: ignore
deps = self.make_dependency_graph(self.dependants, self.conditions)
deps = merge_constraints(deps)
deps, _ = merge_cycles(deps)
deps, _ = remove_loops(deps)
self.dependency_graph = cast(nx.DiGraph, deps.copy())
sorted_program = topological_sort(deps, self.rules)
self) -> List[RuleContainer]:
graph = self.make_dependency_graph(self.dependants, self.conditions)
graph = merge_constraints(graph)
graph, _ = merge_cycles(graph)
graph, _ = remove_loops(graph)
self.dependency_graph = cast(nx.DiGraph, graph.copy())
sorted_program = topological_sort(graph, self.rules)
return sorted_program

def get_index_mapping_for_adjacent_topological_sorts(
self,
sorted_program: List[Tuple[ast.Rule]] # type: ignore
sorted_program: List[RuleContainer]
) -> Dict[int, List[int]]:
if self.dependency_graph is None:
raise ValueError(
Expand All @@ -551,16 +551,16 @@ def get_index_mapping_for_adjacent_topological_sorts(
self.dependency_graph, sorted_program)

def check_positive_recursion(self) -> Set[str]:
deps1 = self.make_dependency_graph(self.dependants,
positive_dependency_graph = self.make_dependency_graph(self.dependants,
self.positive_conditions)
deps1 = merge_constraints(deps1)
deps2, where1 = merge_cycles(deps1)
_, where2 = remove_loops(deps2)
positive_dependency_graph = merge_constraints(positive_dependency_graph)
positive_dependency_graph_withput_cycles, where1 = merge_cycles(positive_dependency_graph)
_, where2 = remove_loops(positive_dependency_graph_withput_cycles)

recursion_rules = set()
for t in where1.union(where2):
if self.should_include_recursive_set(t):
recursion_rules.add(hash_transformation_rules(t))
if any(not is_constraint(r) for r in t.ast):
recursion_rules.add(hash_transformation_rules(t.ast))
return recursion_rules

def should_include_recursive_set(self, recursive_tuple: Tuple[AST, ...]):
Expand Down Expand Up @@ -858,15 +858,7 @@ def transform(program: str, visitor=None, **kwargs):
def reify(transformation: Transformation, **kwargs):
visitor = ProgramReifier(transformation.id, **kwargs)
result: List[AST] = []
rules = transformation.rules
if any(isinstance(r, str) for r in rules):
rules_str = rules
rules = []
for rule in rules_str:
parse_string(
rule, lambda rule: rules.append(rule)
if rule.ast_type != ASTType.Program else None)
for rule in rules:
for rule in transformation.rules.ast:
result.extend(cast(Iterable[AST], visitor.visit(rule)))
return result

Expand Down Expand Up @@ -910,15 +902,6 @@ def reify_recursion_transformation(transformation: Transformation,
**kwargs) -> List[AST]:
visitor = ProgramReifierForRecursions(**kwargs)
result: List[AST] = []
rules = transformation.rules
if any(isinstance(r, str) for r in rules):
rules_str = rules
rules = []
for rule in rules_str:
parse_string(
rule, lambda rule: rules.append(rule)
if rule.ast_type != ASTType.Program else None)

for r in rules:
result.extend(cast(Iterable[AST], visitor.visit(r)))
for rule in transformation.rules.ast:
result.extend(cast(Iterable[AST], visitor.visit(rule)))
return result
2 changes: 1 addition & 1 deletion backend/src/viasp/asp/replayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ..shared.event import Event, publish
from ..shared.model import ClingoMethodCall
from ..shared.simple_logging import warn
from ..shared.util import get_or_create_encoding_id
from ..server.database import get_or_create_encoding_id


def handler(cls):
Expand Down
52 changes: 15 additions & 37 deletions backend/src/viasp/asp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def is_constraint(rule: AST) -> bool:
def merge_constraints(g: nx.DiGraph) -> nx.DiGraph:
mapping = {}
constraints = frozenset([
ruleset for ruleset in g.nodes for rule in ruleset
ruleset for ruleset in g.nodes for rule in ruleset.ast
if is_constraint(rule)
])
if constraints:
Expand All @@ -25,9 +25,9 @@ def merge_constraints(g: nx.DiGraph) -> nx.DiGraph:
return nx.relabel_nodes(g, mapping)


def merge_cycles(g: nx.DiGraph) -> Tuple[nx.DiGraph, FrozenSet[Tuple[AST,...]]]:
def merge_cycles(g: nx.DiGraph) -> Tuple[nx.DiGraph, FrozenSet[RuleContainer]]:
mapping: Dict[AST, AST] = {}
merge_node: Tuple[AST, ...] = ()
merge_node: RuleContainer
where_recursion_happens = set()
for cycle in nx.algorithms.components.strongly_connected_components(g):
merge_node = merge_nodes(cycle)
Expand All @@ -39,16 +39,16 @@ def merge_cycles(g: nx.DiGraph) -> Tuple[nx.DiGraph, FrozenSet[Tuple[AST,...]]]:
return nx.relabel_nodes(g, mapping), frozenset(where_recursion_happens)


def merge_nodes(nodes: frozenset) -> Tuple[AST, ...]:
def merge_nodes(nodes: FrozenSet[RuleContainer]) -> RuleContainer:
old = set()
for x in nodes:
old.update(x)
return tuple(old)
old.update(x.ast)
return RuleContainer(tuple(old))


def remove_loops(g: nx.DiGraph) -> Tuple[nx.DiGraph, FrozenSet[Tuple[AST, ...]]]:
def remove_loops(g: nx.DiGraph) -> Tuple[nx.DiGraph, FrozenSet[RuleContainer]]:
remove_edges: List[Tuple[AST, AST]] = []
where_recursion_happens: Set[Tuple[AST]] = set()
where_recursion_happens: Set[RuleContainer] = set()
for edge in g.edges:
u, v = edge
if u == v:
Expand All @@ -61,28 +61,6 @@ def remove_loops(g: nx.DiGraph) -> Tuple[nx.DiGraph, FrozenSet[Tuple[AST, ...]]]
return g, frozenset(where_recursion_happens)


def rank_topological_sorts(all_sorts: Generator, rules: Sequence[AST]) -> List:
"""
Ranks all topological sorts by the number of rules that are in the same order as in the rules list.
The highest rank is the first element in the list.
:param all_sorts: List of all topological sorts
:param rules: List of rules
"""
ranked_sorts = []
all_sortss = list(all_sorts)
for sort in all_sortss:
rank = 0
sort_rules = [rule for frznst in sort for rule in frznst]
for i in range(len(sort_rules)):
rank -= (rules.index(sort_rules[i]) + 1) * (i + 1)
ranked_sorts.append((sort, rank))
# if len(ranked_sorts)>0:
# break
ranked_sorts.sort(key=lambda x: x[1])
return [x[0] for x in ranked_sorts]


def insert_atoms_into_nodes(path: List[Node]) -> None:
facts = path[0]
state = set(facts.diff)
Expand Down Expand Up @@ -198,7 +176,7 @@ def calculate_spacing_factor(g: nx.DiGraph) -> None:
children_next = []


def topological_sort(g: nx.DiGraph, rules: Sequence[Tuple[ast.Rule]]) -> List: # type: ignore
def topological_sort(g: nx.DiGraph, rules: Sequence[ast.Rule]) -> List: # type: ignore
""" Topological sort of the graph.
If the order is ambiguous, prefer the order of the rules.
Note: Rule = Node
Expand All @@ -215,7 +193,7 @@ def topological_sort(g: nx.DiGraph, rules: Sequence[Tuple[ast.Rule]]) -> List:
earliest_node_index = len(rules)
earliest_node = None
for node in no_incoming_edge:
for rule in node:
for rule in node.ast:
node_index = rules.index(rule)
if node_index < earliest_node_index:
earliest_node_index = node_index
Expand All @@ -238,14 +216,14 @@ def topological_sort(g: nx.DiGraph, rules: Sequence[Tuple[ast.Rule]]) -> List:

def find_index_mapping_for_adjacent_topological_sorts(
g: nx.DiGraph,
sorted_program: List[Tuple[ast.Rule]]) -> Dict[int, List[int]]: # type: ignore
sorted_program: List[RuleContainer]) -> Dict[int, List[int]]:
new_indices: Dict[int, List[int]] = {}
for i, rules_tuple in enumerate(sorted_program):
lower_bound = max([sorted_program.index(u) for u in g.predecessors(rules_tuple)]+[-1])
upper_bound = min([sorted_program.index(u) for u in g.successors(rules_tuple)]+[len(sorted_program)])
for i, rule_container in enumerate(sorted_program):
lower_bound = max([sorted_program.index(u) for u in g.predecessors(rule_container)]+[-1])
upper_bound = min([sorted_program.index(u) for u in g.successors(rule_container)]+[len(sorted_program)])
new_indices[i] = list(range(lower_bound+1,
upper_bound))
new_indices[i].remove(sorted_program.index(rules_tuple))
new_indices[i].remove(sorted_program.index(rule_container))
return new_indices


Expand Down
22 changes: 1 addition & 21 deletions backend/src/viasp/server/blueprints/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,28 +137,8 @@ def set_warnings():
return "ok"


def save_all_sorts(analyzer: ProgramAnalyzer,
batch_size: int = SORTGENERATION_BATCH_SIZE):
sorts = []
t_start = time()
for sorted_program in analyzer.get_sorted_program():
sorts.append(
(hash_from_sorted_transformations(sorted_program), sorted_program))
if len(sorts) >= batch_size:
save_many_sorts(sorts)
sorts = []
if time() - t_start > SORTGENERATION_TIMEOUT_SECONDS:
set_sortable(False)
clear_all_sorts()
break
if len(sorts) == 1:
set_sortable(False)
if sorts:
save_many_sorts(sorts)


def set_primary_sort(analyzer: ProgramAnalyzer):
primary_sort = analyzer.get_primary_sort()
primary_sort = analyzer.get_sorted_program()
primary_hash = hash_from_sorted_transformations(primary_sort)
register_adjacent_sorts(primary_sort, primary_hash)
try:
Expand Down
13 changes: 8 additions & 5 deletions backend/src/viasp/server/blueprints/dag_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from ...shared.defaults import STATIC_PATH
from ...shared.model import Transformation, Node, Signature
from ...shared.util import get_start_node_from_graph, is_recursive, hash_from_sorted_transformations
from ...asp.utils import register_adjacent_sorts, recalculate_transformation_ids
from ...asp.utils import register_adjacent_sorts
from ...shared.io import StableModel
from ..database import load_recursive_transformations_hashes, save_graph, get_graph, clear_graph, set_current_graph, get_adjacent_graphs_hashes, get_current_graph_hash, get_current_sort, load_program, load_transformer, load_models, load_clingraph_names, is_sortable, save_sort, load_dependency_graph
from ..database import load_recursive_transformations_hashes, save_graph, get_graph, clear_graph, set_current_graph, get_current_graph_hash, get_current_sort, load_program, load_transformer, load_models, load_clingraph_names, is_sortable, save_sort, load_dependency_graph


bp = Blueprint("dag_api",
Expand Down Expand Up @@ -176,7 +176,7 @@ def get_possible_transformation_orders():
sorted_program_rules = [t.rules for t in get_current_sort()]
moved_item = sorted_program_rules.pop(moved_transformation["old_index"])
sorted_program_rules.insert(moved_transformation["new_index"], moved_item)
sorted_program_transformations = ProgramAnalyzer(load_dependency_graph()).make_transformations_from_sorted_program(sorted_program_rules)
sorted_program_transformations = ProgramAnalyzer(dependency_graph=load_dependency_graph()).make_transformations_from_sorted_program(sorted_program_rules)
hash = hash_from_sorted_transformations(sorted_program_transformations)
save_sort(hash, sorted_program_transformations)
register_adjacent_sorts(sorted_program_transformations, hash)
Expand Down Expand Up @@ -247,9 +247,12 @@ def entire_graph():
if request.json is None:
return jsonify({'error': 'Missing JSON in request'}), 400
data = request.json['data']
data = nx.node_link_graph(data) if type(data) == dict else data
hash = request.json['hash']
sort = request.json['sort']
sort = current_app.json.loads(sort) if type(sort) == str else sort
save_graph(data, hash, sort)
register_adjacent_sorts(sort, hash)
_ = set_current_graph(hash)
return jsonify({'message': 'ok'}), 200
elif request.method == "GET":
Expand Down Expand Up @@ -342,8 +345,8 @@ def search():
result.append(node)
for _, _, edge in graph.edges(data=True):
transformation = edge["transformation"]
if any(query in str(r) for r in
transformation.rules) and transformation not in result:
if any(query in rule for rule in
transformation.rules.str_) and transformation not in result:
result.append(transformation)
return jsonify(result[:10])
return jsonify([])
Expand Down
8 changes: 7 additions & 1 deletion backend/src/viasp/server/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from clingo.ast import Transformer

from ..shared.defaults import PROGRAM_STORAGE_PATH, GRAPH_PATH
from ..shared.util import get_or_create_encoding_id
from ..shared.event import Event, subscribe
from ..shared.model import ClingoMethodCall, StableModel, Transformation, TransformerTransport, TransformationError

Expand Down Expand Up @@ -66,6 +65,13 @@ def get_pending(self) -> List[ClingoMethodCall]:
def mark_call_as_used(self, call: ClingoMethodCall):
self.used.add(call.uuid)

def get_or_create_encoding_id() -> str:
# TODO
# if 'encoding_id' not in session:
# session['encoding_id'] = uuid4().hex
# print(f"Returing encoding id {session['encoding_id']}", flush=True)
# return session['encoding_id']
return "0"

class GraphAccessor:

Expand Down

0 comments on commit 0e92d93

Please sign in to comment.