Skip to content

Commit

Permalink
Make rustworkx (née retworkx) an optional dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
pjrule committed Apr 25, 2023
1 parent bc1356d commit 48801ed
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions gerrychain/graph/graph.py
Expand Up @@ -4,7 +4,6 @@
import warnings

import networkx
import retworkx
from networkx.classes.function import frozen
from networkx.readwrite import json_graph
import pandas as pd
Expand All @@ -20,6 +19,13 @@ def json_serialize(input_object):
if pd.api.types.is_integer_dtype(input_object): # handle int64
return int(input_object)

try:
import rustworkx
except ImportError:
_has_rustworkx = False
else:
_has_rustworkx = True


class Graph(networkx.Graph):
"""Represents a graph to be partitioned. It is based on :class:`networkx.Graph`.
Expand Down Expand Up @@ -366,32 +372,29 @@ class FrozenGraph:
"rustworkx_networkx_mapping"
]

def __init__(
self,
graph: Graph,
pygraph: retworkx.PyGraph = None,
mappings: Tuple[dict, dict] = None
):
def __init__(self, graph: Graph, pygraph: "rustworkx.PyGraph" = None):
self.graph = networkx.classes.function.freeze(graph)
self.graph.join = frozen
self.graph.add_data = frozen

self.size = len(self.graph)

if pygraph:
self.pygraph = pygraph
else:
self.pygraph = retworkx.networkx_converter(graph, keep_attributes=True)
if _has_rustworkx:
if pygraph is None:
self.pygraph = rustworkx.networkx_converter(graph, keep_attributes=True)
else:
self.pygraph = pygraph

if mappings:
self.retworkx_networkx_mapping, self.networkx_retworkx_mapping = mappings
else:
self.retworkx_networkx_mapping = {
self.rustworkx_networkx_mapping = {
n: self.pygraph[n]["__networkx_node__"] for n in self.pygraph.node_indexes()
}
self.networkx_retworkx_mapping = {
self.networkx_rustworkx_mapping = {
self.pygraph[n]["__networkx_node__"]: n for n in self.pygraph.node_indexes()
}
else:
self.pygraph = None
self.rustworkx_networkx_mapping = None
self.networkx_rustworkx_mapping = None

def __len__(self):
return self.size
Expand Down Expand Up @@ -426,15 +429,20 @@ def lookup(self, node, field):
return self.graph.nodes[node][field]

def subgraph(self, nodes):
if self.pygraph is None:
return FrozenGraph(self.graph.subgraph(nodes))

return FrozenGraph(
self.graph.subgraph(nodes),
self.pygraph.subgraph(
[self.networkx_retworkx_mapping[x] for x in nodes]
[self.networkx_rustworkx_mapping[x] for x in nodes]
)
)

# @functools.cache
def pygraph_pop_lookup(self, field: str):
if self.pygraph is None:
raise ValueError("No rustworkx graph available.")

attrs = [0] * len(self.pygraph.node_indexes())
for node in self.pygraph.node_indexes():
attrs[node] = float(self.pygraph[node][field])
Expand Down

0 comments on commit 48801ed

Please sign in to comment.