Skip to content

Commit

Permalink
Merge pull request #43 from potassco/BUG_All-Ast-#40
Browse files Browse the repository at this point in the history
Bug all ast #40
  • Loading branch information
stephanzwicknagl committed Dec 21, 2023
2 parents cc3d270 + 748c8ee commit 8d186c7
Show file tree
Hide file tree
Showing 23 changed files with 551 additions and 318 deletions.
2 changes: 2 additions & 0 deletions backend/src/viasp/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import textwrap
import webbrowser
import importlib.metadata
from clingo.script import enable_python

from viasp import Control
from viasp.server import startup
Expand Down Expand Up @@ -84,6 +85,7 @@ def start():
options = [str(models)]

backend_url = f"{DEFAULT_BACKEND_PROTOCOL}://{host}:{port}"
enable_python()
ctl = Control(options, viasp_backend_url=backend_url)
for path in paths:
ctl.load(path)
Expand Down
5 changes: 1 addition & 4 deletions backend/src/viasp/asp/ast_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@
ASTType.Literal, ASTType.HeadAggregate, ASTType.HeadAggregateElement, ASTType.BodyAggregate, ASTType.Aggregate,
ASTType.ConditionalLiteral, ASTType.Guard, ASTType.Comparison, ASTType.SymbolicAtom, ASTType.Function,
ASTType.BodyAggregateElement, ASTType.BooleanConstant, ASTType.SymbolicAtom, ASTType.Variable, ASTType.SymbolicTerm,
ASTType.Interval, ASTType.UnaryOperation, ASTType.BinaryOperation
ASTType.Interval, ASTType.UnaryOperation, ASTType.BinaryOperation, ASTType.Defined, ASTType.External, ASTType.ProjectAtom, ASTType.ProjectSignature, ASTType.ShowTerm
}

UNSUPPORTED_TYPES = {
ASTType.Disjunction,
ASTType.TheorySequence, ASTType.TheoryFunction, ASTType.TheoryUnparsedTermElement, ASTType.TheoryUnparsedTerm,
ASTType.TheoryGuard, ASTType.TheoryAtomElement, ASTType.TheoryAtom, ASTType.TheoryOperatorDefinition,
ASTType.TheoryTermDefinition, ASTType.TheoryGuardDefinition, ASTType.TheoryAtomDefinition,
}

def make_unknown_AST_enum_types():
Expand Down
50 changes: 41 additions & 9 deletions backend/src/viasp/asp/reify.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from collections import defaultdict
from typing import Dict, List, Tuple, Iterable, Set, Collection, Any, Union, Sequence
from typing import Dict, List, Tuple, Iterable, Set, Collection, Any, Union, Sequence, cast

import clingo
import networkx as nx
from clingo import ast, Symbol
from clingo.ast import Transformer, parse_string, Rule, ASTType, AST, Literal, Minimize, Disjunction

from .utils import is_constraint, merge_constraints, topological_sort
from .utils import is_constraint, merge_constraints, topological_sort, place_ast_at_location
from ..asp.utils import merge_cycles, remove_loops
from viasp.asp.ast_types import SUPPORTED_TYPES, ARITH_TYPES, UNSUPPORTED_TYPES, UNKNOWN_TYPES
from ..shared.model import Transformation, TransformationError, FailedReason
Expand All @@ -21,7 +21,8 @@ def make_signature(literal: clingo.ast.Literal) -> Tuple[str, int]:
if literal.atom.ast_type in [ast.ASTType.BodyAggregate]:
return literal, 0
unpacked = literal.atom.symbol
if unpacked.ast_type == ast.ASTType.Pool:
if (hasattr(unpacked, "ast_type") and
unpacked.ast_type == ASTType.Pool):
unpacked = unpacked.arguments[0]
return unpacked.name, len(unpacked.arguments) if hasattr(unpacked, "arguments") else 0

Expand Down Expand Up @@ -199,8 +200,8 @@ def __init__(self):
super().__init__()
# TODO: self.dependencies can go?
self.dependencies = nx.DiGraph()
self.dependants: Dict[Tuple[str, int], Set[Rule]] = defaultdict(set)
self.conditions: Dict[Tuple[str, int], Set[Rule]] = defaultdict(set)
self.dependants: Dict[Tuple[str, int], Set[AST]] = defaultdict(set)
self.conditions: Dict[Tuple[str, int], Set[AST]] = defaultdict(set)
self.positive_conditions: Dict[Tuple[str, int], Set[Rule]] = defaultdict(set)
self.rule2signatures = defaultdict(set)
self.facts: Set[Symbol] = set()
Expand Down Expand Up @@ -280,7 +281,7 @@ def register_rule_dependencies(self, rule: Rule, deps: Dict[Literal, List[Litera
):
self.positive_conditions[u_sig].add(rule)

for v in filter(lambda symbol: symbol.atom.ast_type != ASTType.BooleanConstant if hasattr(symbol, "atom") else False, deps.keys()):
for v in filter(lambda symbol: symbol.atom.ast_type != ASTType.BooleanConstant if (hasattr(symbol, "atom") and hasattr(symbol.atom, "ast_type")) else False, deps.keys()):
v_sig = make_signature(v)
self.dependants[v_sig].add(rule)

Expand Down Expand Up @@ -309,16 +310,47 @@ def get_body_aggregate_elements(self, body: Sequence[AST]) -> List[AST]:
self.visit(elem, body_aggregate_elements=body_aggregate_elements)
return body_aggregate_elements

def visit_ShowTerm(self, showTerm: AST):
if (hasattr(showTerm, "location") and isinstance(showTerm.location, ast.Location)
and hasattr(showTerm, "term") and isinstance(showTerm.term, AST)
and hasattr(showTerm, "body") and isinstance(showTerm.body, Sequence)
and all(isinstance(elem, AST) for elem in showTerm.body)):
new_head = ast.Literal(
showTerm.location,
ast.Sign.NoSign,
ast.SymbolicAtom(
showTerm.term
)
)
self.visit(
ast.Rule(
showTerm.location,
new_head,
cast(Sequence, showTerm.body))
)
else:
print(f"Plan B for ShowTerm: {showTerm}", flush=True)
new_rule = ast.Rule(
cast(ast.Location, showTerm.location),
ast.Literal(
cast(ast.Location, showTerm.location),
ast.Sign.NoSign,
cast(AST, showTerm.term)),
cast(Sequence, showTerm.body))
parse_string(place_ast_at_location(new_rule), lambda rule: self.visit(rule))


def visit_Minimize(self, minimize: Minimize):
deps = defaultdict(list)
self.pass_through.add(minimize)

return minimize

def visit_Defined(self, defined: AST):
self.pass_through.add(defined)

def visit_Definition(self, definition):
self.constants.add(definition)
return definition

def add_program(self, program: str, registered_transformer: Transformer = None) -> None:
if registered_transformer is not None:
Expand Down Expand Up @@ -346,8 +378,8 @@ def get_sorted_program(self) -> List[Transformation]:
sorted_program = self.sort_program_by_dependencies()
return [Transformation(i, prg) for i, prg in enumerate(sorted_program)]

def make_dependency_graph(self, head_dependencies: Dict[Tuple[str, int], Iterable[clingo.ast.AST]],
body_dependencies: Dict[Tuple[str, int], Iterable[clingo.ast.AST]]) -> nx.DiGraph:
def make_dependency_graph(self, head_dependencies: Dict[Tuple[str, int], Set[AST]],
body_dependencies: Dict[Tuple[str, int], Set[AST]]) -> nx.DiGraph:
"""
We draw a dependency graph based on which rule head contains which literals.
That way we know, that in order to have a rule r with a body containing literal l, all rules that have l in their
Expand Down
16 changes: 14 additions & 2 deletions backend/src/viasp/asp/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Mostly graph utility functions."""
import networkx as nx
from clingo import Symbol
from clingo.ast import Rule, ASTType
from clingo.ast import Rule, ASTType, AST, Location
from typing import List, Sequence
from ..shared.simple_logging import warn
from ..shared.model import Node, SymbolIdentifier
Expand Down Expand Up @@ -165,4 +165,16 @@ def get_identifiable_reason(g: nx.DiGraph, v: Node, r: Symbol,
# stop criterion: v is the root node and there is no super_graph
warn(f"An explanation could not be made")
return None


def place_ast_at_location(ast: AST) -> str:
"""
Generates a string where ast is located at the
proper location defined in the given AST.
"""
ans = ""
if hasattr(ast,"location") and ast.location != None and isinstance(ast.location, Location):
for i in range(ast.location.begin.line-1):
ans += "\n"
for i in range(ast.location.begin.column-1):
ans += " "
return ans + str(ast)
2 changes: 2 additions & 0 deletions backend/src/viasp/server/blueprints/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def wrap_marked_models(marked_models: Iterable[StableModel]):
wrapped = []
for part in model.atoms:
wrapped.append(f"{part}.")
for part in model.terms:
wrapped.append(f"{part}.")
result.append(wrapped)
return result

Expand Down
114 changes: 68 additions & 46 deletions backend/src/viasp/shared/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import IntEnum
from json import JSONEncoder, JSONDecoder
from dataclasses import is_dataclass
from typing import Any, Union, Collection, Iterable, Dict, Sequence
from typing import Any, Union, Collection, Iterable, Dict, Sequence, cast
from pathlib import PosixPath
from uuid import UUID
import os
Expand Down Expand Up @@ -39,6 +39,12 @@ def object_hook(obj):
return clingo.Function(**obj)
elif t == "Number":
return clingo.Number(**obj)
elif t == "String":
return clingo.String(**obj)
elif t == "Infimum":
return clingo.Infimum
elif t == "Supremum":
return clingo.Supremum
elif t == "Node":
obj['atoms'] = frozenset(obj['atoms'])
obj['diff'] = frozenset(obj['diff'])
Expand All @@ -51,6 +57,8 @@ def object_hook(obj):
return nx.node_link_graph(obj["_graph"])
elif t == "StableModel":
return StableModel(**obj)
elif t == "ModelType":
return ModelType.StableModel
elif t == "ClingoMethodCall":
return ClingoMethodCall(**obj)
elif t == "SymbolIdentifier":
Expand Down Expand Up @@ -146,7 +154,7 @@ def encode_object(o):
elif isinstance(o, PosixPath):
return str(o)
elif isinstance(o, ModelType):
return {"__enum__": str(o)}
return {"_type": "ModelType", "__enum__": str(o)}
elif isinstance(o, Symbol):
x = symbol_to_dict(o)
return x
Expand Down Expand Up @@ -179,12 +187,18 @@ def model_to_dict(model: clingo_Model) -> dict:


def clingo_model_to_stable_model(model: clingo_Model) -> StableModel:
return StableModel(model.cost, model.optimality_proven, model.type, encode_object(model.symbols(atoms=True)),
encode_object(model.symbols(terms=True)), encode_object(model.symbols(shown=True)),
encode_object(model.symbols(theory=True)))
return StableModel(
model.cost,
model.optimality_proven,
model.type,
cast(Collection[Symbol], encode_object(model.symbols(atoms=True))),
cast(Collection[Symbol], encode_object(model.symbols(terms=True))),
cast(Collection[Symbol], encode_object(model.symbols(shown=True))),
cast(Collection[Symbol], encode_object(model.symbols(theory=True))),
)

def clingo_symbols_to_stable_model(atoms: Iterable[Symbol]) -> StableModel:
return StableModel(atoms=encode_object(atoms))
return StableModel(atoms=cast(Collection[Symbol], encode_object(atoms)))

def symbol_to_dict(symbol: clingo.Symbol) -> dict:
symbol_dict = {}
Expand All @@ -196,49 +210,57 @@ def symbol_to_dict(symbol: clingo.Symbol) -> dict:
elif symbol.type == clingo.SymbolType.Number:
symbol_dict["number"] = symbol.number
symbol_dict["_type"] = "Number"
elif symbol.type == clingo.SymbolType.String:
symbol_dict["string"] = symbol.string
symbol_dict["_type"] = "String"
elif symbol.type == clingo.SymbolType.Infimum:
symbol_dict["_type"] = "Infimum"
elif symbol.type == clingo.SymbolType.Supremum:
symbol_dict["_type"] = "Supremum"
return symbol_dict


class viasp_ModelType(IntEnum):
"""
Enumeration of the different types of models.
"""
BraveConsequences = clingo_model_type_brave_consequences
"""
The model stores the set of brave consequences.
"""
CautiousConsequences = clingo_model_type_cautious_consequences
"""
The model stores the set of cautious consequences.
"""
StableModel = clingo_model_type_stable_model
"""
The model captures a stable model.
"""

@classmethod
def from_clingo_ModelType(cls, clingo_ModelType: ModelType):
if clingo_ModelType.name == cls.BraveConsequences.name:
return cls.BraveConsequences
elif clingo_ModelType.name == cls.StableModel.name:
return cls.StableModel
else:
return cls.CautiousConsequences


class ClingoModelEncoder(JSONEncoder):
def default(self, o: Any) -> Any:
if isinstance(o, clingo_Model):
x = model_to_dict(o)
return x
elif isinstance(o, ModelType):
if o in [ModelType.CautiousConsequences, ModelType.BraveConsequences, ModelType.StableModel]:
return {"__enum__": str(o)}
return super().default(o)
elif isinstance(o, Symbol):
x = symbol_to_dict(o)
return x
return super().default(o)
# Legacy: To be deleted in Version 3.0
# class viasp_ModelType(IntEnum):
# """
# Enumeration of the different types of models.
# """
# BraveConsequences = clingo_model_type_brave_consequences
# """
# The model stores the set of brave consequences.
# """
# CautiousConsequences = clingo_model_type_cautious_consequences
# """
# The model stores the set of cautious consequences.
# """
# StableModel = clingo_model_type_stable_model
# """
# The model captures a stable model.
# """

# @classmethod
# def from_clingo_ModelType(cls, clingo_ModelType: ModelType):
# if clingo_ModelType.name == cls.BraveConsequences.name:
# return cls.BraveConsequences
# elif clingo_ModelType.name == cls.StableModel.name:
# return cls.StableModel
# else:
# return cls.CautiousConsequences


# class ClingoModelEncoder(JSONEncoder):
# def default(self, o: Any) -> Any:
# if isinstance(o, clingo_Model):
# x = model_to_dict(o)
# return x
# elif isinstance(o, ModelType):
# if o in [ModelType.CautiousConsequences, ModelType.BraveConsequences, ModelType.StableModel]:
# return {"__enum__": str(o)}
# return super().default(o)
# elif isinstance(o, Symbol):
# x = symbol_to_dict(o)
# return x
# return super().default(o)


def deserialize(data: str, *args, **kwargs):
Expand Down

0 comments on commit 8d186c7

Please sign in to comment.