Skip to content

Commit

Permalink
Cleanup duplicate code
Browse files Browse the repository at this point in the history
Contributes: #34
  • Loading branch information
stephanzwicknagl committed Apr 16, 2024
1 parent 25df4f7 commit 722d5a9
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 57 deletions.
33 changes: 4 additions & 29 deletions backend/src/viasp/asp/reify.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)
from viasp.shared.util import hash_transformation_rules

from .utils import find_index_mapping_for_adjacent_topological_sorts, is_constraint, merge_constraints, topological_sort, filter_body_aggregates, find_adjacent_topological_sorts
from .utils import find_index_mapping_for_adjacent_topological_sorts, is_constraint, merge_constraints, topological_sort, filter_body_aggregates
from ..asp.utils import merge_cycles, remove_loops
from viasp.asp.ast_types import (
SUPPORTED_TYPES,
Expand Down Expand Up @@ -464,10 +464,10 @@ def add(statement):

def sort_program(self, program) -> List[Transformation]:
parse_string(program, lambda rule: self.visit(rule) and None)
sorted_programs = self.sort_program_by_dependencies()
sorted_program = self.primary_sort_program_by_dependencies()
return [
Transformation(i, prg)
for i, prg in enumerate(next(sorted_programs))
for i, prg in enumerate(sorted_program)
]

def get_sort_program_and_graph(
Expand All @@ -479,7 +479,7 @@ def get_sort_program_and_graph(

def get_sorted_program(
self) -> Generator[List[Transformation], None, None]:
sorted_programs = self.sort_program_by_dependencies()
sorted_programs = self.primary_sort_program_by_dependencies()
for program in sorted_programs:
yield [Transformation(i, (prg)) for i, prg in enumerate(program)]

Expand Down Expand Up @@ -529,14 +529,6 @@ def make_dependency_graph(

return g

def sort_program_by_dependencies(self):
deps = self.make_dependency_graph(self.dependants, self.conditions)
deps = merge_constraints(deps)
deps, _ = merge_cycles(deps)
deps, _ = remove_loops(deps)
programs = nx.all_topological_sorts(deps)
return programs

def primary_sort_program_by_dependencies(
self) -> List[ast.Rule]: # type: ignore
deps = self.make_dependency_graph(self.dependants, self.conditions)
Expand All @@ -547,23 +539,6 @@ def primary_sort_program_by_dependencies(
sorted_program = topological_sort(deps, self.rules)
return sorted_program

def get_adjacent_topological_sorts(
self, sorted_program: List[Transformation]
) -> List[List[Transformation]]:
"""
Given a sorted program, return all other valid topological sorts that are achieved
by taking one Transformation and inserting it at another index.
"""
if self.dependency_graph is None:
raise ValueError(
"Dependency graph has not been created yet. Call primary_sort_program_by_dependencies first."
)
adjacent_sorts = find_adjacent_topological_sorts(
self.dependency_graph, [t.rules for t in sorted_program])
return [[
Transformation(i, prg) for i, prg in enumerate(adjacent_sort)
] for adjacent_sort in adjacent_sorts]

def get_index_mapping_for_adjacent_topological_sorts(
self,
sorted_program: List[Tuple[ast.Rule]] # type: ignore
Expand Down
21 changes: 0 additions & 21 deletions backend/src/viasp/asp/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Mostly graph utility functions."""
from re import U
import networkx as nx
from clingo import Symbol, ast
from clingo.ast import ASTType, AST
Expand Down Expand Up @@ -236,37 +235,17 @@ def topological_sort(g: nx.DiGraph, rules: Sequence[Tuple[ast.Rule]]) -> List:
raise Exception("Could not sort the graph.")
return sorted

def find_adjacent_topological_sorts(g: nx.DiGraph, sort: Sequence[Tuple[ast.Rule, ...]]) -> Set[Tuple[ast.Rule, ...]]: # type: ignore
adjacent_topological_sorts: Set[Tuple[ast.Rule, ...]] = set() # type: ignore
for transformation in g.nodes:
lower_bound = max([sort.index(u) for u in g.predecessors(transformation)]+[-1])
upper_bound = min([sort.index(u) for u in g.successors(transformation)]+[len(sort)])
new_indices = list(range(lower_bound+1,
upper_bound))
new_indices.remove(sort.index(transformation))
for new_index in new_indices:
new_sort = list(sort)
new_sort.remove(transformation)
new_sort.insert(new_index, transformation)
adjacent_topological_sorts.add(tuple(new_sort))
return adjacent_topological_sorts


def find_index_mapping_for_adjacent_topological_sorts(
g: nx.DiGraph,
sorted_program: List[Tuple[ast.Rule]]) -> Dict[int, List[int]]: # type: ignore
new_indices: Dict[int, List[int]] = {}
print(f"Recalculate...\n{sorted_program}\n", flush=True)
for i, rules_tuple in enumerate(sorted_program):
print(f"For transformation {rules_tuple}:\n\
max({[sorted_program.index(u) for u in g.predecessors(rules_tuple)]+[-1]})\n\
min({[sorted_program.index(u) for u in g.successors(rules_tuple)]+[len(sorted_program)]})", flush=True)
lower_bound = max([sorted_program.index(u) for u in g.predecessors(rules_tuple)]+[-1])
upper_bound = min([sorted_program.index(u) for u in g.successors(rules_tuple)]+[len(sorted_program)])
new_indices[i] = list(range(lower_bound+1,
upper_bound))
new_indices[i].remove(sorted_program.index(rules_tuple))
print(F"New indices: \n{new_indices}", flush=True)
return new_indices


Expand Down
2 changes: 0 additions & 2 deletions backend/src/viasp/server/blueprints/dag_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,12 @@ def get_possible_transformation_orders():
}

sorted_program_rules = [t.rules for t in get_current_sort()]
print(f"OLD SORT: {sorted_program_rules}", flush=True)
moved_item = sorted_program_rules.pop(moved_transformation["old_index"])
sorted_program_rules.insert(moved_transformation["new_index"], moved_item)
sorted_program_transformations = ProgramAnalyzer(load_dependency_graph()).make_transformations_from_sorted_program(sorted_program_rules)
hash = hash_from_sorted_transformations(sorted_program_transformations)
save_sort(hash, sorted_program_transformations)
register_adjacent_sorts(sorted_program_transformations, hash)
print(f"NEW SORT WITH EDITED ATTRIBUTES: {sorted_program_transformations}", flush =True)
try:
set_current_graph(hash)
except ValueError:
Expand Down
18 changes: 13 additions & 5 deletions backend/src/viasp/shared/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from clingo import Symbol, ModelType
from clingo.ast import AST, Transformer, Rule
from viasp.shared.io import get_rules_from_input_program
from .util import DefaultMappingProxyType, hash_transformation_rules

@dataclass()
Expand Down Expand Up @@ -110,18 +111,25 @@ def __eq__(self, o):
def __repr__(self):
return f"Transformation(id={self.id}, rules={list(map(str,self.rules))}, adjacent_sort_indices={self.adjacent_sort_indices}, hash={self.hash})"

@dataclass(frozen=True)

@dataclass(frozen=False)
class RuleContainer:
rules: Tuple[Union[AST, str], ...]
rules_AST: Tuple[AST, ...] = field(default_factory=tuple, hash=True)
rules_str: Tuple[str, ...] = field(default_factory=tuple, hash=False)

def __post_init__(self):
if isinstance(self.rules_AST, AST) and len(self.rules_str) == 0:
self.rules_str = tuple(get_rules_from_input_program(self.rules_AST))

def __hash__(self):
return hash(self.rules)
return hash(self.rules_AST)

def __eq__(self, o):
return isinstance(o, type(self)) and self.rules == o.rules
return isinstance(o, type(self)) and self.rules_AST == o.rules_AST

def __repr__(self):
return str(self.rules)
return str(self.rules_str)


@dataclass(frozen=True)
class Signature:
Expand Down

0 comments on commit 722d5a9

Please sign in to comment.