Add algorithm for finding adjacent orders
This commit adds the first elements for an efficient (lazy) generation
of possible orders for transformations.

Contributes: #34
stephanzwicknagl committed Apr 13, 2024
1 parent 1e152bb commit 1c463dc
Showing 11 changed files with 470 additions and 179 deletions.
75 changes: 56 additions & 19 deletions backend/src/viasp/asp/reify.py
@@ -12,7 +12,7 @@
)
from viasp.shared.util import hash_transformation_rules

from .utils import is_constraint, merge_constraints, topological_sort, filter_body_aggregates
from .utils import find_index_mapping_for_adjacent_topological_sorts, is_constraint, merge_constraints, topological_sort, filter_body_aggregates, find_adjacent_topological_sorts
from ..asp.utils import merge_cycles, remove_loops
from viasp.asp.ast_types import (
SUPPORTED_TYPES,
@@ -29,7 +29,9 @@ def is_fact(rule, dependencies):


def make_signature(literal: ast.Literal) -> Tuple[str, int]: # type: ignore
if literal.atom.ast_type in [ASTType.BodyAggregate, ASTType.BooleanConstant]:
if literal.atom.ast_type in [
ASTType.BodyAggregate, ASTType.BooleanConstant
]:
return literal, 0
unpacked = literal.atom.symbol
if unpacked.ast_type in [ASTType.Variable, ASTType.Function]:
@@ -47,6 +49,7 @@ def make_signature(literal: ast.Literal) -> Tuple[str, int]: # type: ignore
return (unpacked.name, len(unpacked.arguments))
raise ValueError(f"Could not make signature of {literal}")


def filter_body_arithmetic(elem: ast.Literal): # type: ignore
elem_ast_type = getattr(getattr(elem, "atom", ""), "ast_type", None)
return elem_ast_type not in ARITH_TYPES
@@ -99,7 +102,9 @@ def visit(self, ast: AST, *args: Any, **kwargs: Any) -> Union[AST, None]:
class DependencyCollector(Transformer):

def __init__(self, **kwargs):
self.compound_atoms_types: List = [ASTType.Aggregate, ASTType.BodyAggregate, ASTType.Comparison]
self.compound_atoms_types: List = [
ASTType.Aggregate, ASTType.BodyAggregate, ASTType.Comparison
]
self.in_analyzer = kwargs.get("in_analyzer", False)

def visit_ConditionalLiteral(
@@ -171,8 +176,9 @@ def __init__(self):
Set[ast.Rule]] = defaultdict(set) # type: ignore
self.conditions: Dict[Tuple[str, int],
Set[ast.Rule]] = defaultdict(set) # type: ignore
self.positive_conditions: Dict[Tuple[str, int],
Set[ast.Rule]] = defaultdict(set) # type: ignore
self.positive_conditions: Dict[Tuple[
str, int], Set[ast.Rule]] = defaultdict( # type: ignore
set)
self.rule2signatures = defaultdict(set)
self.facts: Set[Symbol] = set()
self.constants: Set[Symbol] = set()
@@ -395,7 +401,10 @@ def process_body(self, head, body, deps, in_analyzer=True):
if not len(deps) and len(body):
deps[head] = ([], [])
for _, (cond, pos_cond) in deps.items():
self.visit_sequence(body, conditions=cond, positive_conditions=pos_cond, in_analyzer=in_analyzer)
self.visit_sequence(body,
conditions=cond,
positive_conditions=pos_cond,
in_analyzer=in_analyzer)

def register_dependencies_and_append_rule(self, rule, deps):
self.register_rule_dependencies(rule, deps)
@@ -460,6 +469,13 @@ def sort_program(self, program) -> List[Transformation]:
for i, prg in enumerate(next(sorted_programs))
]

def get_sort_program_and_graph(
self, program: str) -> Tuple[List[Tuple[AST, ...]], nx.DiGraph]:
parse_string(program, lambda rule: self.visit(rule) and None)
sorted_programs = self.primary_sort_program_by_dependencies()
return sorted_programs, self.make_dependency_graph(
self.dependants, self.conditions)

def get_sorted_program(
self) -> Generator[List[Transformation], None, None]:
sorted_programs = self.sort_program_by_dependencies()
@@ -468,7 +484,9 @@ def get_sorted_program(

def get_primary_sort(self) -> List[Transformation]:
sorted_program = self.primary_sort_program_by_dependencies()
return [Transformation(i, prg) for i, prg in enumerate(sorted_program)]
adjacency_index_mapping = self.get_index_mapping_for_adjacent_topological_sorts(
sorted_program)
return [Transformation(i, prg, adjacency_index_mapping[i]) for i, prg in enumerate(sorted_program)]

def make_dependency_graph(
self,
@@ -496,8 +514,7 @@ def make_dependency_graph(
dependent_rules = body_dependencies.get(head_signature, [])
for parent_rule in rules_with_head:
for dependent_rule in dependent_rules:
g.add_edge(tuple([parent_rule]),
tuple([dependent_rule]))
g.add_edge(tuple([parent_rule]), tuple([dependent_rule]))

return g

@@ -515,8 +532,25 @@ def primary_sort_program_by_dependencies(
deps = merge_constraints(deps)
deps, _ = merge_cycles(deps)
deps, _ = remove_loops(deps)
programs = topological_sort(deps, self.rules)
return programs
self.dependency_graph = deps
sorted_program = topological_sort(deps, self.rules)
return sorted_program

def get_adjacent_topological_sorts(
self, sorted_program: List[Transformation]
) -> List[List[Transformation]]:
"""
        Given a sorted program, return every other valid topological sort that
        can be obtained by taking one Transformation and moving it to a
        different index.
"""
adjacent_sorts = find_adjacent_topological_sorts(
self.dependency_graph, [t.rules for t in sorted_program])
return [[
Transformation(i, prg) for i, prg in enumerate(adjacent_sort)
] for adjacent_sort in adjacent_sorts]

def get_index_mapping_for_adjacent_topological_sorts(self, sorted_program: List[Transformation]) -> Dict[int, List[int]]:
return find_index_mapping_for_adjacent_topological_sorts(self.dependency_graph, sorted_program)

def check_positive_recursion(self) -> Set[str]:
deps1 = self.make_dependency_graph(self.dependants,
@@ -609,7 +643,8 @@ def process_dependant_intervals(
self, loc: ast.Location,
dependant: Union[ast.Literal, ast.Function]): # type: ignore
if dependant.ast_type == ASTType.Function:
dependant = ast.Literal(loc, ast.Sign.NoSign, ast.SymbolicAtom(dependant))
dependant = ast.Literal(loc, ast.Sign.NoSign,
ast.SymbolicAtom(dependant))
if has_an_interval(dependant):
# replace dependant with variable: e.g. (1..3) -> X
variables = [
@@ -761,7 +796,7 @@ def visit_Rule(self, rule: ast.Rule) -> List[AST]: # type: ignore
return [rule]
if not deps:
# if it's a "simple head"
deps[rule.head] = ([],[])
deps[rule.head] = ([], [])
new_rules: List[ast.Rule] = [] # type: ignore
for dependant, (conditions, _) in deps.items():
dependant = self.process_dependant_intervals(loc, dependant)
@@ -771,7 +806,6 @@ def visit_Rule(self, rule: ast.Rule) -> List[AST]: # type: ignore
conditions=conditions,
)


self.replace_anon_variables(conditions)
new_head_s = self._nest_rule_head_in_h_with_explanation_tuple(
rule.location, dependant, conditions)
@@ -804,8 +838,6 @@ def make_loc_lit(self, loc: ast.Location) -> ast.Literal: # type: ignore
return ast.Literal(loc, ast.Sign.NoSign, loc_atm)




def register_rules(rule_or_list_of_rules, rulez):
if isinstance(rule_or_list_of_rules, list):
for rule in rule_or_list_of_rules:
@@ -833,7 +865,9 @@ def reify(transformation: Transformation, **kwargs):
rules_str = rules
rules = []
for rule in rules_str:
parse_string(rule, lambda rule: rules.append(rule) if rule.ast_type != ASTType.Program else None)
parse_string(
rule, lambda rule: rules.append(rule)
if rule.ast_type != ASTType.Program else None)
for rule in rules:
result.extend(cast(Iterable[AST], visitor.visit(rule)))
return result
@@ -874,15 +908,18 @@ def has_an_interval(literal: ast.Literal) -> bool: # type: ignore
return False


def reify_recursion_transformation(transformation: Transformation, **kwargs) -> List[AST]:
def reify_recursion_transformation(transformation: Transformation,
**kwargs) -> List[AST]:
visitor = ProgramReifierForRecursions(**kwargs)
result: List[AST] = []
rules = transformation.rules
if any(isinstance(r, str) for r in rules):
rules_str = rules
rules = []
for rule in rules_str:
parse_string(rule, lambda rule: rules.append(rule) if rule.ast_type != ASTType.Program else None)
parse_string(
rule, lambda rule: rules.append(rule)
if rule.ast_type != ASTType.Program else None)

for r in rules:
result.extend(cast(Iterable[AST], visitor.visit(r)))
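A brief usage sketch (not part of the commit) of the analyzer additions above. Only the example program is invented; the classes and helpers are the ones introduced in this diff, and exact transformation counts depend on how the analyzer treats facts.

# Sketch: run the analyzer on a small program and enumerate the orders
# reachable from the primary sort by moving a single transformation.
from viasp.asp.reify import ProgramAnalyzer
from viasp.asp.utils import find_adjacent_topological_sorts

program = """
a(1..3).
b(X) :- a(X).
c(X) :- a(X).
d(X) :- b(X), c(X).
"""

analyzer = ProgramAnalyzer()
sorted_program, _ = analyzer.get_sort_program_and_graph(program)
# primary_sort_program_by_dependencies() stored the merged dependency graph
# on the analyzer, which is the graph the new helper expects.
neighbours = find_adjacent_topological_sorts(analyzer.dependency_graph,
                                             sorted_program)
print(len(sorted_program), "transformations,", len(neighbours), "adjacent orders")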
34 changes: 33 additions & 1 deletion backend/src/viasp/asp/utils.py
@@ -3,8 +3,9 @@
from clingo import Symbol, ast
from clingo.ast import ASTType, AST
from typing import Generator, List, Sequence, Tuple, Dict, Set, FrozenSet, Optional

from ..shared.simple_logging import warn
from ..shared.model import Node, SymbolIdentifier
from ..shared.model import Node, SymbolIdentifier, Transformation
from ..shared.util import pairwise, get_root_node_from_graph


@@ -234,6 +235,37 @@ def topological_sort(g: nx.DiGraph, rules: Sequence[ast.Rule]) -> List: # type:
raise Exception("Could not sort the graph.")
return sorted

def find_adjacent_topological_sorts(
        g: nx.DiGraph, sort: Sequence[Tuple[ast.Rule, ...]]  # type: ignore
) -> Set[Tuple[ast.Rule, ...]]:  # type: ignore
    adjacent_topological_sorts: Set[Tuple[ast.Rule, ...]] = set()  # type: ignore
    for transformation in g.nodes:
        # a transformation may be placed anywhere strictly after its last
        # predecessor and strictly before its first successor
        lower_bound = max([sort.index(u) for u in g.predecessors(transformation)] + [-1])
        upper_bound = min([sort.index(u) for u in g.successors(transformation)] + [len(sort)])
        new_indices = list(range(lower_bound + 1, upper_bound))
        new_indices.remove(sort.index(transformation))
        for new_index in new_indices:
            new_sort = list(sort)
            new_sort.remove(transformation)
            new_sort.insert(new_index, transformation)
            adjacent_topological_sorts.add(tuple(new_sort))
    return adjacent_topological_sorts


def find_index_mapping_for_adjacent_topological_sorts(
        g: nx.DiGraph,
        sorted_program: List[Transformation]) -> Dict[int, List[int]]:
    new_indices: Dict[int, List[int]] = {}
    # Accept Transformation objects as well as bare rule tuples, since
    # get_primary_sort currently passes the raw topological sort.
    sort = [
        t.rules if isinstance(t, Transformation) else t for t in sorted_program
    ]
    # Iterate in sort order so that the keys line up with the positions in
    # sorted_program (the iteration order of g.nodes is not guaranteed to).
    for i, transformation in enumerate(sort):
        lower_bound = max([sort.index(u) for u in g.predecessors(transformation)] + [-1])
        upper_bound = min([sort.index(u) for u in g.successors(transformation)] + [len(sort)])
        new_indices[i] = list(range(lower_bound + 1, upper_bound))
        new_indices[i].remove(i)
    return new_indices


def filter_body_aggregates(element: AST):
aggregate_types = [
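For reference, a toy example (not part of the commit) of the adjacency computation above, using string tuples as stand-ins for rule tuples, since the helpers only rely on graph structure and list positions.

# Toy graph: a must precede b and c; b and c are mutually independent.
import networkx as nx
from viasp.asp.utils import find_adjacent_topological_sorts

g = nx.DiGraph()
g.add_edges_from([(("a",), ("b",)), (("a",), ("c",))])
sort = [("a",), ("b",), ("c",)]  # one valid topological order

print(find_adjacent_topological_sorts(g, sort))
# -> {(('a',), ('c',), ('b',))}: the only neighbouring order swaps b and c.
# The index-mapping variant encodes the same information positionally:
# {0: [], 1: [2], 2: [1]} -- a cannot move, b could move to index 2,
# c could move to index 1.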
13 changes: 7 additions & 6 deletions backend/src/viasp/server/blueprints/api.py
@@ -6,13 +6,12 @@

from clingo import Control
from clingraph.orm import Factbase
from clingo.ast import Transformer
from clingraph.graphviz import compute_graphs, render

from .dag_api import generate_graph, set_current_graph, wrap_marked_models, \
load_program, load_transformer, load_models, \
load_clingraph_names
from ..database import CallCenter, get_database, save_recursive_transformations_hashes, set_models, clear_models, save_many_sorts, save_clingraph, clear_clingraph, save_transformer, save_warnings, clear_warnings, load_warnings, save_warnings, set_sortable, clear_all_sorts
from ..database import CallCenter, get_database, save_recursive_transformations_hashes, set_models, clear_models, save_many_sorts, save_sort, save_clingraph, clear_clingraph, save_transformer, save_warnings, clear_warnings, load_warnings, save_warnings, set_sortable, clear_all_sorts
from ...asp.reify import ProgramAnalyzer
from ...asp.relax import ProgramRelaxer, relax_constraints
from ...shared.model import ClingoMethodCall, StableModel, TransformerTransport
@@ -163,15 +162,18 @@ def set_primary_sort(analyzer: ProgramAnalyzer):
try:
_ = set_current_graph(primary_hash)
except KeyError:
save_many_sorts([((hash_from_sorted_transformations(primary_sort),
primary_sort))])
save_sort(hash_from_sorted_transformations(primary_sort),
primary_sort)
# TODO: also extract the adjacent sorts from the primary_sort and
# register those in the database
generate_graph()
except ValueError:
generate_graph()


def save_analyzer_values(analyzer: ProgramAnalyzer):
save_recursive_transformations_hashes(analyzer.check_positive_recursion())
## TODO: save attributes



@@ -185,9 +187,8 @@ def show_selected_models():
marked_models = wrap_marked_models(marked_models,
analyzer.get_conflict_free_showTerm())
if analyzer.will_work():
save_all_sorts(analyzer, batch_size=SORTGENERATION_BATCH_SIZE)
save_analyzer_values(analyzer)
set_primary_sort(analyzer)
save_analyzer_values(analyzer)

return "ok", 200

11 changes: 9 additions & 2 deletions backend/src/viasp/server/blueprints/dag_api.py
@@ -14,7 +14,7 @@
from ...shared.model import Transformation, Node, Signature
from ...shared.util import get_start_node_from_graph, is_recursive, hash_from_sorted_transformations
from ...shared.io import StableModel
from ..database import load_recursive_transformations_hashes, save_graph, get_graph, clear_graph, set_current_graph, get_all_sorts, get_current_sort, load_program, load_transformer, load_models, load_clingraph_names, is_sortable
from ..database import load_recursive_transformations_hashes, save_graph, get_graph, clear_graph, set_current_graph, get_adjacent_graphs_hashes, get_current_graph_hash, get_current_sort, load_program, load_transformer, load_models, load_clingraph_names, is_sortable, save_sort


bp = Blueprint("dag_api",
@@ -168,13 +168,20 @@ def get_possible_transformation_orders():
if request.json is None:
return jsonify({'error': 'Missing JSON in request'}), 400
hash = request.json["hash"]
sort = request.json["sort"] if "sort" in request.json else [] # could be as much as the switched indices
if sort:
# TODO: calculate the new adjacent_sort_indices for this sort
hash = hash_from_sorted_transformations(sort)
save_sort(hash, sort)
# TODO: also extract the adjacent sorts from the primary_sort and
# register those in the database
try:
set_current_graph(hash)
except ValueError:
generate_graph()
return "ok", 200
elif request.method == "GET":
sorts = get_all_sorts()
sorts = get_adjacent_graphs_hashes(get_current_graph_hash())
return jsonify(sorts)
raise NotImplementedError

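A hypothetical client-side sketch of the handler above. The URL rule and backend address are assumptions, since the @bp.route decorator and server configuration are outside this excerpt; the shape of the GET payload depends on get_adjacent_graphs_hashes.

# Assumed route "/graph/sorts" and port 5050 -- placeholders, not confirmed
# by this diff.
import requests

BASE = "http://localhost:5050"
SORTS_URL = f"{BASE}/graph/sorts"

# GET: hashes of the sorts adjacent to the currently selected graph.
adjacent_hashes = requests.get(SORTS_URL).json()

# POST: switch to one of them by hash; sending a "sort" would register a
# newly shuffled order before switching to it.
if adjacent_hashes:
    requests.post(SORTS_URL, json={"hash": adjacent_hashes[0]})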
