Skip to content

Commit

Permalink
Save all warnings in database as json strings
Browse files Browse the repository at this point in the history
  • Loading branch information
stephanzwicknagl committed Feb 22, 2024
1 parent 7a61ea8 commit 6a8e330
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 58 deletions.
8 changes: 2 additions & 6 deletions backend/src/viasp/server/blueprints/api.py
Expand Up @@ -12,7 +12,7 @@
from .dag_api import generate_graph, set_current_graph, wrap_marked_models, \
load_program, load_transformer, load_models, \
load_clingraph_names
from ..database import CallCenter, get_database, set_models, clear_models, save_many_sorts, save_clingraph, clear_clingraph, save_transformer, save_warnings, clear_warnings, load_warnings
from ..database import CallCenter, get_database, set_models, clear_models, save_many_sorts, save_clingraph, clear_clingraph, save_transformer, save_warnings, clear_warnings, load_warnings, save_warnings
from ...asp.reify import ProgramAnalyzer
from ...asp.relax import ProgramRelaxer, relax_constraints
from ...shared.model import ClingoMethodCall, StableModel, TransformerTransport
Expand Down Expand Up @@ -123,10 +123,6 @@ def set_transformer():
return "ok", 200


def _set_warnings(warnings):
encoding_id = get_or_create_encoding_id()
get_database().save_warnings(warnings, encoding_id)


@bp.route("/control/warnings", methods=["POST", "DELETE", "GET"])
def set_warnings():
Expand Down Expand Up @@ -168,7 +164,7 @@ def save_analyzer_values(analyzer: ProgramAnalyzer):
def show_selected_models():
analyzer = ProgramAnalyzer()
analyzer.add_program(load_program(), load_transformer())
_set_warnings(analyzer.get_filtered())
save_warnings(analyzer.get_filtered())

marked_models = load_models()
marked_models = wrap_marked_models(marked_models,
Expand Down
111 changes: 60 additions & 51 deletions backend/src/viasp/server/database.py
Expand Up @@ -10,7 +10,7 @@
from ..shared.defaults import PROGRAM_STORAGE_PATH, GRAPH_PATH
from ..shared.util import get_or_create_encoding_id
from ..shared.event import Event, subscribe
from ..shared.model import ClingoMethodCall, StableModel, Transformation, TransformerTransport
from ..shared.model import ClingoMethodCall, StableModel, Transformation, TransformerTransport, TransformationError



Expand Down Expand Up @@ -163,7 +163,7 @@ def save_transformer(transformer: TransformerTransport):
encoding_id = get_or_create_encoding_id()
get_database().save_transformer(transformer, encoding_id)

def save_warnings(warnings: List[str]):
def save_warnings(warnings: List[TransformationError]):
encoding_id = get_or_create_encoding_id()
get_database().save_warnings(warnings, encoding_id)

Expand Down Expand Up @@ -245,30 +245,30 @@ def __init__(self):
# ENCODING #
# # # # # # #

def save_program(self, program: str, encoding_id: str): #
def save_program(self, program: str, encoding_id: str):
self.cursor.execute(
"""
INSERT OR REPLACE INTO encodings (program, id) VALUES (?, ?)
""", (program, encoding_id))
self.conn.commit()

def add_to_program(self, program: str, encoding_id: str): #
def add_to_program(self, program: str, encoding_id: str):
program = self.load_program(encoding_id) + program
self.cursor.execute(
"""
INSERT OR REPLACE INTO encodings (id, program) VALUES (?, ?)
""", (encoding_id, program))
self.conn.commit()

def load_program(self, encoding_id: str) -> str: #
def load_program(self, encoding_id: str) -> str:
self.cursor.execute(
"""
SELECT program FROM encodings WHERE id = (?)
""", (encoding_id, ))
result = self.cursor.fetchone()
return result[0] if result is not None else ""

def clear_program(self, encoding_id: str): #
def clear_program(self, encoding_id: str):
self.cursor.execute(
"""
DELETE FROM encodings WHERE id = (?)
Expand All @@ -279,20 +279,18 @@ def clear_program(self, encoding_id: str): #
# MODELS #
# # # # # # #

def set_models(self, parsed_models: Sequence[Union[StableModel, str]], encoding_id: str): #
def set_models(self, parsed_models: Sequence[Union[StableModel, str]],
encoding_id: str):
self.clear_models(encoding_id)
for model in parsed_models:
# if isinstance(model, str):
# json_model = model
# else:
json_model = current_app.json.dumps(model)
self.cursor.execute(
"""
INSERT INTO models (encoding_id, model) VALUES (?, ?)
""", (encoding_id, json_model))
self.conn.commit()

def load_models(self, encoding_id: str) -> List[StableModel]: #
def load_models(self, encoding_id: str) -> List[StableModel]:
self.cursor.execute(
"""
SELECT model FROM models WHERE encoding_id = (?)
Expand All @@ -301,7 +299,7 @@ def load_models(self, encoding_id: str) -> List[StableModel]: #

return [current_app.json.loads(r[0]) for r in result]

def clear_models(self, encoding_id: str): #
def clear_models(self, encoding_id: str):
self.cursor.execute(
"""
DELETE FROM models WHERE encoding_id = (?)
Expand All @@ -313,15 +311,15 @@ def clear_models(self, encoding_id: str): #
# # # # # # # #

def save_graph(self, graph: nx.Graph, hash: str,
sort: List[Transformation], encoding_id: str): #
sort: List[Transformation], encoding_id: str):
self.cursor.execute(
"""
INSERT OR REPLACE INTO graphs (data, hash, sort, encoding_id) VALUES (?, ?, ?, ?)
""", (current_app.json.dumps(nx.node_link_data(graph)), hash, current_app.json.dumps(sort),
encoding_id))
""", (current_app.json.dumps(nx.node_link_data(graph)), hash,
current_app.json.dumps(sort), encoding_id))
self.conn.commit()

def set_current_graph(self, hash: str, encoding_id: str): #
def set_current_graph(self, hash: str, encoding_id: str):
self.cursor.execute(
"""
DELETE FROM current_graph WHERE encoding_id = (?)
Expand All @@ -331,7 +329,7 @@ def set_current_graph(self, hash: str, encoding_id: str): #
(hash, encoding_id))
self.conn.commit()

def get_current_graph(self, encoding_id: str) -> str: #
def get_current_graph(self, encoding_id: str) -> str:
self.cursor.execute(
"""
SELECT hash FROM current_graph WHERE encoding_id = (?)
Expand All @@ -355,28 +353,29 @@ def load_graph(self, hash: str, encoding_id: str) -> nx.DiGraph:
graph_json_str = self.load_graph_json(hash, encoding_id)
return nx.node_link_graph(current_app.json.loads(graph_json_str))

def load_current_graph_json(self, encoding_id: str) -> str: #
def load_current_graph_json(self, encoding_id: str) -> str:
hash = self.get_current_graph(encoding_id)
return self.load_graph_json(hash, encoding_id)

def load_current_graph(self, encoding_id: str) -> nx.DiGraph: #
def load_current_graph(self, encoding_id: str) -> nx.DiGraph:
graph_json_str = self.load_current_graph_json(encoding_id)
# return current_app.json.loads(graph_json_str)
return nx.node_link_graph(current_app.json.loads(graph_json_str))

# # # # # # # #
# SORTS #
# # # # # # # #

def save_many_sorts(self, sorts: List[Tuple[str, List[Transformation], str]]):
def save_many_sorts(self, sorts: List[Tuple[str, List[Transformation],
str]]):
self.cursor.executemany(
"""
INSERT OR REPLACE INTO graphs (hash, data, sort, encoding_id) VALUES (?, ?, ?, ?)
""", [(hash, None, current_app.json.dumps(sort), encoding_id)
for hash, sort, encoding_id in sorts])
self.conn.commit()

def save_sort(self, hash: str, sort: List[Transformation], encoding_id: str):
def save_sort(self, hash: str, sort: List[Transformation],
encoding_id: str):
self.cursor.execute(
"""
INSERT OR REPLACE INTO graphs (hash, data, sort, encoding_id) VALUES (?, ?, ?, ?)
Expand Down Expand Up @@ -436,47 +435,23 @@ def load_all_clingraphs(self, encoding_id: str) -> List[str]:
return [r[0] for r in result]

# # # # # # # #
# GENERAL #
# WARNINGS #
# # # # # # # #

def clear(self):
self.cursor.execute("DELETE FROM encodings")
self.cursor.execute("DELETE FROM models")
self.cursor.execute("DELETE FROM graphs")
self.cursor.execute("DELETE FROM current_graph")
self.cursor.execute("DELETE FROM clingraph")
self.cursor.execute("DELETE FROM transformer")
self.cursor.execute("DELETE FROM warnings")
self.conn.commit()

def save_transformer(self, transformer: TransformerTransport, encoding_id: str):
self.cursor.execute(
"""
INSERT OR REPLACE INTO transformer (transformer, encoding_id) VALUES (?, ?)
""", (current_app.json.dumps(transformer), encoding_id))
self.conn.commit()

def load_transformer(self, encoding_id: str) -> Optional[Transformer]:
self.cursor.execute(
"""
SELECT transformer FROM transformer WHERE encoding_id = (?)
""", (encoding_id, ))
result = self.cursor.fetchone()
return result[0] if result is not None else None

def clear_warnings(self, encoding_id: str):
self.cursor.execute(
"""
DELETE FROM warnings WHERE encoding_id = (?)
""", (encoding_id, ))
self.conn.commit()

def save_warnings(self, warnings: List[str], encoding_id: str):
def save_warnings(self, warnings: List[TransformationError],
encoding_id: str):
for warning in warnings:
self.cursor.execute(
"""
INSERT INTO warnings (encoding_id, warning) VALUES (?, ?)
""", (encoding_id, warning))
""", (encoding_id, current_app.json.dumps(warning)))
self.conn.commit()

def load_warnings(self, encoding_id: str) -> List[str]:
Expand All @@ -485,4 +460,38 @@ def load_warnings(self, encoding_id: str) -> List[str]:
SELECT warning FROM warnings WHERE encoding_id = (?)
""", (encoding_id, ))
result = self.cursor.fetchall()
return [r[0] for r in result]
return [current_app.json.loads(r[0]) for r in result]

# # # # # # # # # # # # # # #
# REGISTERED TRANFORMER #
# # # # # # # # # # # # # # #

def save_transformer(self, transformer: TransformerTransport,
encoding_id: str):
self.cursor.execute(
"""
INSERT OR REPLACE INTO transformer (transformer, encoding_id) VALUES (?, ?)
""", (current_app.json.dumps(transformer), encoding_id))
self.conn.commit()

def load_transformer(self, encoding_id: str) -> Optional[Transformer]:
self.cursor.execute(
"""
SELECT transformer FROM transformer WHERE encoding_id = (?)
""", (encoding_id, ))
result = self.cursor.fetchone()
return current_app.json.loads(result[0]) if result is not None else None

# # # # # # # #
# GENERAL #
# # # # # # # #

def clear(self):
self.cursor.execute("DELETE FROM encodings")
self.cursor.execute("DELETE FROM models")
self.cursor.execute("DELETE FROM graphs")
self.cursor.execute("DELETE FROM current_graph")
self.cursor.execute("DELETE FROM clingraph")
self.cursor.execute("DELETE FROM transformer")
self.cursor.execute("DELETE FROM warnings")
self.conn.commit()
25 changes: 24 additions & 1 deletion backend/test/test_database.py
Expand Up @@ -8,7 +8,7 @@
from viasp.shared.util import hash_from_sorted_transformations
from viasp.asp.reify import reify_list
from viasp.asp.justify import build_graph
from viasp.shared.model import TransformerTransport
from viasp.shared.model import TransformerTransport, TransformationError, FailedReason
from viasp.exampleTransformer import Transformer as ExampleTransfomer


Expand Down Expand Up @@ -175,6 +175,29 @@ def test_clingraph_database():
assert len(r) == 0


def test_warnings(app_context, load_analyzer, program_simple):
db = GraphAccessor()

encoding_id = "test"
analyzer = load_analyzer(program_simple)
some_ast = analyzer.rules[0]
warnings = [
TransformationError(ast=some_ast, reason=FailedReason.FAILURE),
TransformationError(ast=some_ast, reason=FailedReason.FAILURE)
]
r = db.load_warnings(encoding_id)
assert type(r) == list
assert len(r) == 0
db.save_warnings(warnings, encoding_id)
r = db.load_warnings(encoding_id)
assert type(r) == list
assert len(r) == 2
db.clear_warnings(encoding_id)
r = db.load_warnings(encoding_id)
assert type(r) == list
assert len(r) == 0


@pytest.mark.skip(reason="Transformer not registered bc of base exception?")
def test_transformer_database(app_context):
db = GraphAccessor()
Expand Down

0 comments on commit 6a8e330

Please sign in to comment.