Skip to content

Commit

Permalink
Improve signature for showTerm statements
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanzwicknagl committed Feb 27, 2024
1 parent 7895f5b commit 4f802a1
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 19 deletions.
22 changes: 16 additions & 6 deletions backend/src/viasp/asp/reify.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,23 @@ def make_signature(literal: ast.Literal) -> Tuple[str, int]: # type: ignore
if literal.atom.ast_type in [ASTType.BodyAggregate]:
return literal, 0
unpacked = literal.atom.symbol
if hasattr(unpacked, "ast_type") and unpacked.ast_type == ASTType.Pool:
if unpacked.ast_type in [ASTType.Variable, ASTType.Function]:
return (
unpacked.name,
len(unpacked.arguments) if hasattr(unpacked, "arguments") else 0,
)
if unpacked.ast_type == ASTType.SymbolicTerm:
return (
unpacked.symbol.name,
len(unpacked.arguments) if hasattr(unpacked, "arguments") else 0,
)
if unpacked.ast_type == ASTType.Pool:
unpacked = unpacked.arguments[0]
return (
unpacked.name,
len(unpacked.arguments) if hasattr(unpacked, "arguments") else 0,
)

return (
unpacked.name,
len(unpacked.arguments)
)
raise ValueError(f"Could not make signature of {literal}")

def filter_body_arithmetic(elem: ast.Literal): # type: ignore
elem_ast_type = getattr(getattr(elem, "atom", ""), "ast_type", None)
Expand Down
5 changes: 0 additions & 5 deletions backend/src/viasp/server/blueprints/dag_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,11 +393,9 @@ def generate_graph() -> nx.DiGraph:
marked_models = load_models()
marked_models = wrap_marked_models(marked_models,
analyzer.get_conflict_free_showTerm())
print(f"Using marked model: {marked_models}", flush=True)
if analyzer.will_work():
recursion_rules = analyzer.check_positive_recursion()
sorted_program = get_current_sort()
print(F"Sorted program: {sorted_program}", flush=True)
reified: Collection[AST] = reify_list(
sorted_program,
h=analyzer.get_conflict_free_h(),
Expand All @@ -406,11 +404,8 @@ def generate_graph() -> nx.DiGraph:
conflict_free_showTerm=analyzer.get_conflict_free_showTerm(),
get_conflict_free_variable=analyzer.get_conflict_free_variable,
clear_temp_names=analyzer.clear_temp_names)
print(f"Reified program: {list(map(str,reified))}", flush=True)
g = build_graph(marked_models, reified, sorted_program, analyzer,
recursion_rules)
print(f"Got graph with {len(g.nodes)} nodes and {len(g.edges)} edges.",
flush=True)
save_graph(g, hash_from_sorted_transformations(sorted_program),
sorted_program)

Expand Down
8 changes: 0 additions & 8 deletions backend/test/test_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,14 +309,6 @@ def test_ast_types_do_not_intersect(app_context):
known), "No type should be known and unknown"


def test_disjunction_causes_error_and_doesnt_get_passed():
program = "a; b."

transformer = ProgramAnalyzer()
result = transformer.sort_program(program)
assert len(transformer.get_filtered())
assert len(result) == 0


@pytest.mark.skip(reason="Not implemented yet")
def test_constraints_gets_put_last(app_context):
Expand Down
4 changes: 4 additions & 0 deletions backend/test/test_reification.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ def test_disjunctions_in_head():
assertProgramEqual(transform(rule), parse_program_to_ast(expected))

def test_showTerm_transformed_correctly():
rule = "#show a : b."
expected = "h_showTerm(1, a, (b,)) :- showTerm(a), b."
assertProgramEqual(transform(rule), parse_program_to_ast(expected))

rule = "#show a(X) : b(X)."
expected = "h_showTerm(1, a(X), (b(X),)) :- showTerm(a(X)), b(X)."
assertProgramEqual(transform(rule), parse_program_to_ast(expected))

0 comments on commit 4f802a1

Please sign in to comment.