Skip to content

Commit

Permalink
Add support for script input
Browse files Browse the repository at this point in the history
External script's functions are now supported without yielding warning.

Minor fixes to transformers using the dependency collector.

Resolves: #63
  • Loading branch information
stephanzwicknagl committed Mar 3, 2024
1 parent b7eb72d commit 3d777d2
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 17 deletions.
2 changes: 1 addition & 1 deletion backend/src/viasp/asp/ast_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
ASTType.Literal, ASTType.HeadAggregate, ASTType.HeadAggregateElement, ASTType.BodyAggregate, ASTType.Aggregate,
ASTType.ConditionalLiteral, ASTType.Guard, ASTType.Comparison, ASTType.SymbolicAtom, ASTType.Function,
ASTType.BodyAggregateElement, ASTType.BooleanConstant, ASTType.SymbolicAtom, ASTType.Variable, ASTType.SymbolicTerm,
ASTType.Interval, ASTType.UnaryOperation, ASTType.BinaryOperation, ASTType.Defined, ASTType.External, ASTType.ProjectAtom, ASTType.ProjectSignature, ASTType.ShowTerm, ASTType.Minimize
ASTType.Interval, ASTType.UnaryOperation, ASTType.BinaryOperation, ASTType.Defined, ASTType.External, ASTType.ProjectAtom, ASTType.ProjectSignature, ASTType.ShowTerm, ASTType.Minimize, ASTType.Script
}

UNSUPPORTED_TYPES = {
Expand Down
25 changes: 14 additions & 11 deletions backend/src/viasp/asp/reify.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ def visit(self, ast: AST, *args: Any, **kwargs: Any) -> Union[AST, None]:

class DependencyCollector(Transformer):

def __init__(self, **kwargs):
self.compound_atoms_types: List = [ASTType.Aggregate, ASTType.BodyAggregate, ASTType.Comparison]
self.in_analyzer = kwargs.get("in_analyzer", False)

def visit_ConditionalLiteral(
self,
conditional_literal: ast.ConditionalLiteral, # type: ignore
Expand Down Expand Up @@ -129,16 +133,15 @@ def visit_Literal(
**kwargs: Any) -> AST:
conditions: List[AST] = kwargs.get("conditions", [])
positive_conditions: List[AST] = kwargs.get("positive_conditions", [])
in_analyzer = kwargs.get("in_analyzer", False)
in_aggregate = kwargs.get("in_aggregate", False)

if (in_analyzer and literal.atom.ast_type
not in [ASTType.Aggregate, ASTType.BodyAggregate]):
# all non-aggregate Literals in the rule body are conditions of the rule
if (self.in_analyzer
and literal.atom.ast_type not in self.compound_atoms_types):
# all non-compound Literals in the rule body are conditions of the rule
conditions.append(literal)
if literal.sign == ast.Sign.NoSign and not in_aggregate:
positive_conditions.append(literal)
if (not in_analyzer and not in_aggregate):
if (not self.in_analyzer and not in_aggregate):
# add all Literals outside of aggregates from rule body to justifier rule body
conditions.append(literal)
return literal.update(**self.visit_children(literal, **kwargs))
Expand All @@ -165,7 +168,8 @@ class ProgramAnalyzer(DependencyCollector, FilteredTransformer):
"""

def __init__(self):
super().__init__()
DependencyCollector.__init__(self, in_analyzer=True)
FilteredTransformer.__init__(self)
self.dependants: Dict[Tuple[str, int],
Set[ast.Rule]] = defaultdict(set) # type: ignore
self.conditions: Dict[Tuple[str, int],
Expand Down Expand Up @@ -560,6 +564,7 @@ def __init__(self,
self.get_conflict_free_variable = get_conflict_free_variable
self.clear_temp_names = clear_temp_names
self.conflict_free_showTerm = conflict_free_showTerm
super().__init__(in_analyzer=False)

def make_loc_lit(self, loc: ast.Location) -> ast.Literal: # type: ignore
loc_fun = ast.Function(loc, str(self.rule_nr), [], False)
Expand Down Expand Up @@ -607,8 +612,6 @@ def process_dependant_intervals(
self, loc: ast.Location,
dependant: Union[ast.Literal, ast.Function]): # type: ignore
if dependant.ast_type == ASTType.Function:
print(f"Type of dependant {dependant.ast_type}, going to make lit",
flush=True)
dependant = ast.Literal(loc, ast.Sign.NoSign, ast.SymbolicAtom(dependant))
if has_an_interval(dependant):
# replace dependant with variable: e.g. (1..3) -> X
Expand Down Expand Up @@ -750,17 +753,17 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def visit_Rule(self, rule: ast.Rule) -> List[AST]: # type: ignore
deps = defaultdict(list)
deps = defaultdict(tuple)
loc = cast(ast.Location, rule.location)
_ = self.visit(rule.head, deps=deps, in_head=True)

if is_fact(rule, deps) or is_constraint(rule):
return [rule]
if not deps:
# if it's a "simple head"
deps[rule.head] = []
deps[rule.head] = ([],[])
new_rules: List[ast.Rule] = [] # type: ignore
for dependant, conditions in deps.items():
for dependant, (conditions, _) in deps.items():
dependant = self.process_dependant_intervals(loc, dependant)

_ = self.visit_sequence(
Expand Down
2 changes: 1 addition & 1 deletion backend/test/test_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test2():
program = transformer.sort_program(program)
filtered = transformer.get_filtered()
will_work = transformer.will_work()
assert len(filtered) == 1, "Script Statement should be filtered out."
assert len(filtered) == 0, "Script Statement should not be filtered out."
assert will_work == True, "Program with ScriptTerm should work."
# assertProgramEqual(rules, parse_program_to_ast(expected))

Expand Down
4 changes: 0 additions & 4 deletions backend/test/test_reification.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from typing import List

from clingo import Control
from clingo.ast import AST, ASTType, parse_string
import pytest

from viasp.asp.ast_types import (SUPPORTED_TYPES, UNSUPPORTED_TYPES,
make_unknown_AST_enum_types)
from viasp.asp.reify import ProgramAnalyzer, transform


Expand Down

0 comments on commit 3d777d2

Please sign in to comment.