Skip to content

Commit

Permalink
Adds DISTINCT (#40)
Browse files Browse the repository at this point in the history
* Implements `DISTINCT`

* Adds unit tests for `DISTINCT`
  • Loading branch information
jackboyla committed May 2, 2024
1 parent 31001b7 commit 6f4e32c
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 7 deletions.
31 changes: 24 additions & 7 deletions grandcypher/__init__.py
Expand Up @@ -80,11 +80,12 @@
return_clause : "return"i entity_id ("," entity_id)*
| "return"i entity_id ("," entity_id)* limit_clause
| "return"i entity_id ("," entity_id)* skip_clause
| "return"i entity_id ("," entity_id)* skip_clause limit_clause
return_clause : "return"i distinct_return? entity_id ("," entity_id)*
| "return"i distinct_return? entity_id ("," entity_id)* limit_clause
| "return"i distinct_return? entity_id ("," entity_id)* skip_clause
| "return"i distinct_return? entity_id ("," entity_id)* skip_clause limit_clause
distinct_return : "DISTINCT"i
limit_clause : "limit"i NUMBER
skip_clause : "skip"i NUMBER
Expand Down Expand Up @@ -319,6 +320,7 @@ def __init__(self, target_graph: nx.Graph, limit=None):
self._matche_paths = None
self._return_requests = []
self._return_edges = {}
self._distinct = False
self._limit = limit
self._skip = 0
self._max_hop = 100
Expand Down Expand Up @@ -391,14 +393,18 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
result[data_path] = list(ret)[offset_limit]

return result

def return_clause(self, clause):

def return_clause(self, clause):
    """Record the identifiers requested by a RETURN clause.

    Each truthy item of ``clause`` is appended to
    ``self._return_requests``. Items that are already ``str`` are stored
    unchanged; anything else is stored as ``str(item.value)``.

    Falsy items are skipped. Presumably this filters out the ``None``
    produced by the optional ``distinct_return`` child -- confirm against
    the grammar/transformer wiring.
    """
    for entry in clause:
        if not entry:
            continue
        name = entry if isinstance(entry, str) else str(entry.value)
        self._return_requests.append(name)

def distinct_return(self, distinct):
    """Mark the current query as ``RETURN DISTINCT``.

    The ``distinct`` argument is not inspected; being called at all is
    what switches ``self._distinct`` on. Nothing is returned.
    """
    self._distinct = True

def limit_clause(self, limit):
limit = int(limit[-1])
self._limit = limit
Expand All @@ -413,7 +419,18 @@ def returns(self, ignore_limit=False):
else:
offset_limit = slice(self._skip, None)

return self._lookup(self._return_requests, offset_limit=offset_limit)
results = self._lookup(self._return_requests, offset_limit=offset_limit)

if self._distinct:

# process distinct for each key in results
distinct_results = {}
for key, values in results.items():
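# remove duplicates (NOTE: set() does not preserve the original order, each key is de-duplicated independently of the others, and this runs after skip/limit were applied by _lookup above)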
distinct_results[key] = list(set(values))
results = distinct_results

return results

def _get_true_matches(self):
if not self._matches:
Expand Down
83 changes: 83 additions & 0 deletions grandcypher/test_queries.py
Expand Up @@ -588,6 +588,89 @@ def test_complex_where(self):
assert res["B"] == ["y", "z"]


class TestDistinct:
    """Tests for ``RETURN DISTINCT`` in GrandCypher queries."""

    def test_basic_distinct(self):
        """A name shared by two nodes is reported only once."""
        host = nx.DiGraph()
        for node_id, name in (("a", "Alice"), ("b", "Bob"), ("c", "Alice")):
            host.add_node(node_id, name=name)

        qry = """
        MATCH (n)
        RETURN DISTINCT n.name
        """
        names = GrandCypher(host).run(qry)["n.name"]
        # "Alice" appears on two nodes but must be listed a single time.
        assert len(names) == 2
        assert "Alice" in names and "Bob" in names

    def test_distinct_with_relationships(self):
        """Both edge sources are named "Alice", so one value remains."""
        host = nx.DiGraph()
        for node_id, name in (("a", "Alice"), ("b", "Bob"), ("c", "Alice")):
            host.add_node(node_id, name=name)
        for source, target in (("a", "b"), ("c", "b")):
            host.add_edge(source, target)

        qry = """
        MATCH (n)-[]->(b)
        RETURN DISTINCT n.name
        """
        names = GrandCypher(host).run(qry)["n.name"]
        assert len(names) == 1
        assert names == ["Alice"]

    def test_distinct_with_limit_and_skip(self):
        """DISTINCT combined with SKIP 1 LIMIT 1 yields exactly one name."""
        host = nx.DiGraph()
        # Nodes are inserted alternately: a0, b0, a1, b1, ...
        for i in range(5):
            host.add_node(f"a{i}", name="Alice")
            host.add_node(f"b{i}", name="Bob")

        qry = """
        MATCH (n)
        RETURN DISTINCT n.name SKIP 1 LIMIT 1
        """
        names = GrandCypher(host).run(qry)["n.name"]
        assert len(names) == 1
        # The single surviving value is expected to be "Bob".
        assert names == ["Bob"]

    def test_distinct_on_complex_graph(self):
        """Two returned columns over a chain a -> b -> c -> d."""
        host = nx.DiGraph()
        people = (("a", "Alice"), ("b", "Bob"), ("c", "Carol"), ("d", "Alice"))
        for node_id, name in people:
            host.add_node(node_id, name=name)
        for source, target in (("a", "b"), ("b", "c"), ("c", "d")):
            host.add_edge(source, target)

        qry = """
        MATCH (n)-[]->(m)
        RETURN DISTINCT n.name, m.name
        """
        res = GrandCypher(host).run(qry)

        # Edge sources are a, b, c -> Alice, Bob, Carol.
        source_names = res["n.name"]
        assert len(source_names) == 3
        assert (
            "Alice" in source_names
            and "Bob" in source_names
            and "Carol" in source_names
        )
        # Edge targets are b, c, d -> Bob, Carol, Alice.
        assert len(res["m.name"]) == 3

    def test_distinct_with_attributes(self):
        """DISTINCT considers only the returned attribute, not ``age``."""
        host = nx.DiGraph()
        host.add_node("a", name="Alice", age=25)
        host.add_node("b", name="Alice", age=30)  # same name, different age
        host.add_node("c", name="Bob", age=25)

        qry = """
        MATCH (n)
        WHERE n.age > 20
        RETURN DISTINCT n.name
        """
        names = GrandCypher(host).run(qry)["n.name"]
        assert len(names) == 2
        assert "Alice" in names and "Bob" in names


class TestVariableLengthRelationship:
def test_single_variable_length_relationship(self):
host = nx.DiGraph()
Expand Down

0 comments on commit 6f4e32c

Please sign in to comment.