Rustworkx acceleration #381

Draft: wants to merge 26 commits into main

Commits (26)

c1ff2f8  Add __slots__ and replace PopulatedGraph.degree with direct dictionar… (InnovativeInventor, Jan 6, 2022)
b72238f  Switch to using more performant/preferred lookup call (InnovativeInventor, Jan 6, 2022)
3692209  Add retworkx converter (InnovativeInventor, Jan 10, 2022)
a5e90af  Use graph.node_indicies in PopulatedGraph (InnovativeInventor, Jan 10, 2022)
ea2cb12  Add keep_attributes flag (will be released in retworkx) (InnovativeInventor, Jan 12, 2022)
a74c6c2  Naive bipartition_tree retworkx impl (InnovativeInventor, Jan 12, 2022)
4f99cd4  Improve subgraph instantiation times (InnovativeInventor, Jan 13, 2022)
d3f0dc4  More perf improvements (InnovativeInventor, Jan 13, 2022)
190d2ed  Stop using recursive_tree_part (InnovativeInventor, Jan 13, 2022)
87b3905  Revert "Stop using recursive_tree_part" (InnovativeInventor, Jan 13, 2022)
d9b3c23  Resolve memory leak issue with caching (InnovativeInventor, Jan 13, 2022)
73e9990  Prevent unnecessary subgraphing (InnovativeInventor, Jan 13, 2022)
3b3f36c  Maybe faster record perf (todo: benchmark carefully) (InnovativeInventor, Jan 14, 2022)
9c2f38c  Re-enable lookup cache (InnovativeInventor, Feb 15, 2022)
6484374  Use dict.update operator to support older versions of Python (InnovativeInventor, Mar 28, 2022)
eeb317b  Fix misspelled node_indicies [sic] (InnovativeInventor, May 16, 2022)
f1310d9  Fix linter issues (InnovativeInventor, May 16, 2022)
c9d06ba  Add retworkx as setup.py dep (InnovativeInventor, May 17, 2022)
063e077  Reduce overhead of to_series call (InnovativeInventor, May 18, 2022)
762b8ac  Switch to __getattr__ call to reduce overhead of attr fetches (InnovativeInventor, May 18, 2022)
1ea4df9  Force self.mapping to maintain sorted order invariant for faster seri… (InnovativeInventor, May 18, 2022)
bd7b604  Update retworkx call to use renamed function (InnovativeInventor, May 23, 2022)
bc1356d  Rename bipartition_graph to bipartition_graph_mst (InnovativeInventor, May 27, 2022)
48801ed  Make rustworkx (née retworkx) an optional dependency (pjrule, Apr 25, 2023)
37831ba  Revert changes to Partition class (pjrule, Apr 25, 2023)
1ce30d0  WIP: Use gerrychain.rs for acceleration (pjrule, Apr 29, 2023)

Files changed (4)

79 changes: 68 additions & 11 deletions gerrychain/graph/graph.py
@@ -1,6 +1,6 @@
import functools
import json
from typing import Any
from typing import Any, Tuple
import warnings

import networkx
@@ -19,6 +19,13 @@ def json_serialize(input_object):
if pd.api.types.is_integer_dtype(input_object): # handle int64
return int(input_object)

try:
from gerrychain_rs import rustworkx
except ImportError:
_has_rust_extensions = False
else:
_has_rust_extensions = True


class Graph(networkx.Graph):
"""Represents a graph to be partitioned. It is based on :class:`networkx.Graph`.
@@ -357,24 +364,57 @@ class FrozenGraph:
This speeds up chain runs and prevents having to deal with cache invalidation issues.
This class behaves slightly differently than :class:`Graph` or :class:`networkx.Graph`.
"""

__slots__ = ["graph", "size"]

def __init__(self, graph: Graph):
__slots__ = [
"graph",
"size",
"pygraph",
"networkx_rustworkx_mapping",
"rustworkx_networkx_mapping"
]

def __init__(self, graph: Graph, pygraph: "rustworkx.PyGraph" = None):
self.graph = networkx.classes.function.freeze(graph)
self.graph.join = frozen
self.graph.add_data = frozen

self.size = len(self.graph)

if _has_rust_extensions:
if graph.is_directed():
raise ValueError("Frozen graphs must be undirected.")

if pygraph is None:
# adapted from `rustworkx.networkx_converter`.
self.pygraph = rustworkx.PyGraph(multigraph=graph.is_multigraph())
nodes = list(graph.nodes)
node_indices = dict(zip(nodes, self.pygraph.add_nodes_from(nodes)))
self.pygraph.add_edges_from(
[(node_indices[x[0]], node_indices[x[1]], x[2]) for x in graph.edges(data=True)]
)

for node, node_index in node_indices.items():
attributes = graph.nodes[node]
attributes["__networkx_node__"] = node
self.pygraph[node_index] = attributes
else:
self.pygraph = pygraph

self.rustworkx_networkx_mapping = {
n: self.pygraph[n]["__networkx_node__"] for n in self.pygraph.node_indexes()
}
self.networkx_rustworkx_mapping = {
self.pygraph[n]["__networkx_node__"]: n for n in self.pygraph.node_indexes()
}
else:
self.pygraph = None
self.rustworkx_networkx_mapping = None
self.networkx_rustworkx_mapping = None

def __len__(self):
return self.size

def __getattribute__(self, __name: str) -> Any:
try:
return object.__getattribute__(self, __name)
except AttributeError:
return object.__getattribute__(self.graph, __name)
def __getattr__(self, __name: str) -> Any:
return getattr(self.graph, __name)

def __getitem__(self, __name: str) -> Any:
return self.graph[__name]
@@ -403,4 +443,21 @@ def lookup(self, node, field):
return self.graph.nodes[node][field]

def subgraph(self, nodes):
return FrozenGraph(self.graph.subgraph(nodes))
if self.pygraph is None:
return FrozenGraph(self.graph.subgraph(nodes))

return FrozenGraph(
self.graph.subgraph(nodes),
self.pygraph.subgraph(
[self.networkx_rustworkx_mapping[x] for x in nodes]
)
)

def pygraph_pop_lookup(self, field: str):
if self.pygraph is None:
raise ValueError("No rustworkx graph available.")

attrs = [0] * len(self.pygraph.node_indexes())
for node in self.pygraph.node_indexes():
attrs[node] = float(self.pygraph[node][field])
return attrs
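
Note on the block above: when the Rust extensions import, FrozenGraph mirrors the networkx graph into a rustworkx.PyGraph and keeps a two-way mapping between networkx node ids and rustworkx integer indices. Below is a minimal standalone sketch (not part of the diff) that builds the same mapping with the public rustworkx.networkx_converter, whose keep_attributes=True option stores each original node under the "__networkx_node__" key that the code above reads; the toy path graph is illustrative only.

import networkx as nx
import rustworkx

# Stand-in for a gerrychain dual graph.
nx_graph = nx.path_graph(4)

# keep_attributes=True makes the converter store the original networkx node
# under "__networkx_node__" in each node's payload dict.
pygraph = rustworkx.networkx_converter(nx_graph, keep_attributes=True)

# Two-way index mapping, mirroring rustworkx_networkx_mapping /
# networkx_rustworkx_mapping in the diff.
rx_to_nx = {i: pygraph[i]["__networkx_node__"] for i in pygraph.node_indexes()}
nx_to_rx = {node: i for i, node in rx_to_nx.items()}

assert set(nx_to_rx) == set(nx_graph.nodes)

The diff builds the mapping inline rather than calling the converter, which also lets __init__ accept a pre-built pygraph (the path FrozenGraph.subgraph takes).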
30 changes: 26 additions & 4 deletions gerrychain/proposals/tree_proposals.py
@@ -1,13 +1,13 @@
from functools import partial
from ..random import random

from ..random import random
from ..tree import (
recursive_tree_part, bipartition_tree, bipartition_tree_random,
bipartition_tree_random,
_bipartition_tree_random_all, uniform_spanning_tree,
find_balanced_edge_cuts_memoization,
find_balanced_edge_cuts_memoization, bipartition_tree_retworkx,
recursive_tree_part, bipartition_tree,
)


def recom(
partition, pop_col, pop_target, epsilon, node_repeats=1, method=bipartition_tree
):
@@ -58,6 +58,28 @@ def recom(
return partition.flip(flips)


def recom_rust(partition, pop_col, pop_target, epsilon):
"""Accelerated ReCom proposal (experimental, requires GerryChain.rs)."""
edge = random.choice(tuple(partition["cut_edges"]))
parts_to_merge = (partition.assignment.mapping[edge[0]], partition.assignment.mapping[edge[1]])

subgraph = partition.graph.subgraph(
partition.parts[parts_to_merge[0]] | partition.parts[parts_to_merge[1]]
)

flips_left, flips_right = bipartition_tree_retworkx(
subgraph,
pop_col=pop_col,
pop_target=pop_target,
epsilon=epsilon
)

flips = {node: parts_to_merge[0] for node in flips_left}
flips.update({node: parts_to_merge[1] for node in flips_right})

return partition.flip(flips)


def reversible_recom(partition, pop_col, pop_target, epsilon,
balance_edge_fn=find_balanced_edge_cuts_memoization, M=1,
repeat_until_valid=False, choice=random.choice):
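
Usage note (illustrative, not from the PR): recom_rust would be wired into a chain the same way recom is, via functools.partial. The JSON path, the "district" assignment column, the "population" column and updater name, and all numbers below are placeholders; the import path simply follows the file edited above, and running this requires the optional gerrychain_rs extension.

from functools import partial

from gerrychain import Graph, MarkovChain, Partition
from gerrychain.accept import always_accept
from gerrychain.constraints import within_percent_of_ideal_population
from gerrychain.proposals.tree_proposals import recom_rust
from gerrychain.updaters import Tally, cut_edges

graph = Graph.from_json("dual_graph.json")  # placeholder path
initial_partition = Partition(
    graph,
    assignment="district",  # placeholder assignment column
    updaters={"population": Tally("population"), "cut_edges": cut_edges},
)

ideal_pop = sum(initial_partition["population"].values()) / len(initial_partition.parts)

chain = MarkovChain(
    proposal=partial(recom_rust, pop_col="population",
                     pop_target=ideal_pop, epsilon=0.02),
    constraints=[within_percent_of_ideal_population(initial_partition, 0.02)],
    accept=always_accept,
    initial_state=initial_partition,
    total_steps=1000,
)

for partition in chain:
    pass  # collect statistics on each step here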
72 changes: 60 additions & 12 deletions gerrychain/tree.py
@@ -1,3 +1,5 @@
from gerrychain.graph.graph import FrozenGraph

import networkx as nx
from networkx.algorithms import tree

@@ -6,6 +8,13 @@
from collections import deque, namedtuple
from typing import Any, Callable, Dict, List, Optional, Set, Union, Sequence

try:
import gerrychain_rs
except ImportError:
_has_rust_extensions = False
else:
_has_rust_extensions = True


def predecessors(h: nx.Graph, root: Any) -> Dict:
return {a: b for a, b in nx.bfs_predecessors(h, root)}
@@ -78,7 +87,7 @@ def __init__(
self.tot_pop = sum(self.population.values())
self.ideal_pop = ideal_pop
self.epsilon = epsilon
self._degrees = {node: graph.degree(node) for node in graph.nodes}
self.degrees = {node: graph.degree(node) for node in graph.nodes}

def __iter__(self):
return iter(self.graph)
@@ -89,7 +98,7 @@ def degree(self, node) -> int:
def contract_node(self, node, parent) -> None:
self.population[parent] += self.population[node]
self.subsets[parent] |= self.subsets[node]
self._degrees[parent] -= 1
self.degrees[parent] -= 1

def has_ideal_population(self, node) -> bool:
return (
@@ -103,20 +112,20 @@ def has_ideal_population(self, node) -> bool:
def find_balanced_edge_cuts_contraction(
h: PopulatedGraph, choice: Callable = random.choice) -> List[Cut]:
# this used to be greater than 2 but failed on small grids:(
root = choice([x for x in h if h.degree(x) > 1])
root = choice([x for x in h if h.degrees[x] > 1])
# BFS predecessors for iteratively contracting leaves
pred = predecessors(h.graph, root)

cuts = []
leaves = deque(x for x in h if h.degree(x) == 1)
leaves = deque(x for x in h if h.degrees[x] == 1)
while len(leaves) > 0:
leaf = leaves.popleft()
if h.has_ideal_population(leaf):
cuts.append(Cut(edge=(leaf, pred[leaf]), subset=h.subsets[leaf].copy()))
# Contract the leaf:
parent = pred[leaf]
h.contract_node(leaf, parent)
if h.degree(parent) == 1 and parent != root:
if h.degrees[parent] == 1 and parent != root:
leaves.append(parent)
return cuts

@@ -171,6 +180,39 @@ def part_nodes(start):
return cuts


def bipartition_tree_rust(
graph: FrozenGraph,
pop_col: str,
pop_target: float,
epsilon: float,
choice=random.choice
):
"""This function finds a balanced 2-partition of a graph by drawing a
spanning tree and finding an edge to cut that leaves at most an epsilon
imbalance between the populations of the parts.

Uses Rust extensions (GerryChain.rs).
"""
if not _has_rust_extensions:
raise ImportError(
"GerryChain.rs is required to use accelerated tree functions."
)

pops = graph.pygraph_pop_lookup(pop_col)
balanced_node_choices = gerrychain_rs.bipartition_graph_mst(
graph.pygraph,
lambda _: random.random(),
pops,
float(pop_target),
float(epsilon)
)
balanced_nodes = {
graph.rustworkx_networkx_mapping[x]
for x in choice(balanced_node_choices)[1]
}
return (balanced_nodes, graph.node_indices - balanced_nodes)


def bipartition_tree(
graph: nx.Graph,
pop_col: str,
@@ -183,7 +225,7 @@ def bipartition_tree(
choice: Callable = random.choice,
max_attempts: Optional[int] = None
) -> Set:
"""This function finds a balanced 2 partition of a graph by drawing a
"""This function finds a balanced 2-partition of a graph by drawing a
spanning tree and finding an edge to cut that leaves at most an epsilon
imbalance between the populations of the parts. If a root fails, new roots
are tried until node_repeats in which case a new tree is drawn.
@@ -208,7 +250,7 @@ def bipartition_tree(
:param choice: :func:`random.choice`. Can be substituted for testing.
:param max_atempts: The max number of attempts that should be made to bipartition.
"""
populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices}
populations = {node: graph.lookup(node, pop_col) for node in graph.node_indices}

possible_cuts = []
if spanning_tree is None:
@@ -246,7 +288,7 @@ def _bipartition_tree_random_all(
max_attempts: Optional[int] = None
):
"""Randomly bipartitions a graph and returns all cuts."""
populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices}
populations = {node: graph.lookup(node, pop_col) for node in graph.node_indices}

possible_cuts = []
if spanning_tree is None:
@@ -369,8 +411,14 @@ def recursive_tree_part(
for part in parts[:-1]:
min_pop = max(pop_target * (1 - epsilon), pop_target * (1 - epsilon) - debt)
max_pop = min(pop_target * (1 + epsilon), pop_target * (1 + epsilon) - debt)

if len(parts[:-1]) == 1: # prevent unnecessary subgraphing
subgraph = graph
else:
subgraph = graph.subgraph(remaining_nodes)

nodes = method(
graph.subgraph(remaining_nodes),
subgraph,
pop_col=pop_col,
pop_target=(min_pop + max_pop) / 2,
epsilon=(max_pop - min_pop) / (2 * pop_target),
@@ -383,7 +431,7 @@ def recursive_tree_part(
part_pop = 0
for node in nodes:
flips[node] = part
part_pop += graph.nodes[node][pop_col]
part_pop += graph.lookup(node, pop_col)
debt += part_pop - pop_target
remaining_nodes -= nodes

@@ -426,7 +474,7 @@ def get_seed_chunks(

chunk_pop = 0
for node in graph.node_indices:
chunk_pop += graph.nodes[node][pop_col]
chunk_pop += graph.lookup(node, pop_col)

while True:
epsilon = abs(epsilon)
@@ -466,7 +514,7 @@ def get_seed_chunks(

part_pop = 0
for node in remaining_nodes:
part_pop += graph.nodes[node][pop_col]
part_pop += graph.lookup(node, pop_col)
part_pop_as_dist = part_pop / num_chunks_left
fake_epsilon = epsilon
if num_chunks_left != 1:
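
To make the return contract concrete: bipartition_tree_rust returns a pair of complementary sets of networkx node ids, which recom_rust then assigns to the two merged districts. A hedged sketch on a toy grid follows (assumes gerrychain_rs is installed; the "population" attribute name and the numbers are illustrative only, not values from this PR).

import networkx as nx

from gerrychain import Graph
from gerrychain.graph.graph import FrozenGraph
from gerrychain.tree import bipartition_tree_rust

# 4x4 grid with unit population on every node.
grid = nx.grid_graph([4, 4])
nx.set_node_attributes(grid, 1, "population")

# FrozenGraph also builds the rustworkx PyGraph when gerrychain_rs is importable.
frozen = FrozenGraph(Graph(grid))

left, right = bipartition_tree_rust(
    frozen, pop_col="population", pop_target=8, epsilon=0.05
)

# The halves are disjoint and together cover the graph.
assert not (left & right)
assert left | right == set(frozen.graph.nodes)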
4 changes: 2 additions & 2 deletions setup.py
@@ -6,7 +6,6 @@
long_description = f.read()

requirements = [
# package requirements go here
"pandas",
"scipy",
"networkx",
@@ -36,6 +35,7 @@
"License :: OSI Approved :: BSD License",
],
extras_require={
'geo': ["shapely>=2.0.1", "geopandas>=0.12.2"]
'geo': ["shapely>=2.0.1", "geopandas>=0.12.2"],
'rust': ["gerrychain_rs>=0.1"],
}
)