Skip to content

Commit

Permalink
Merge pull request #386 from mggg/frozen-graph
Browse files Browse the repository at this point in the history
Frozen graph implementation!
  • Loading branch information
gabeschoenbach committed Feb 15, 2022
2 parents 6f4e9b9 + a97beed commit 17fa8bf
Show file tree
Hide file tree
Showing 15 changed files with 145 additions and 37 deletions.
85 changes: 83 additions & 2 deletions gerrychain/graph/graph.py
@@ -1,8 +1,11 @@
import functools
import json
from typing import Any
import warnings

import geopandas as gp
import networkx
from networkx.classes.function import frozen
from networkx.readwrite import json_graph
from shapely.ops import unary_union
from shapely.prepared import prep
Expand All @@ -18,10 +21,14 @@ class Graph(networkx.Graph):
to save and load graphs as JSON files.
"""

def __repr__(self):
return "<Graph [{} nodes, {} edges]>".format(len(self.nodes), len(self.edges))

@classmethod
def from_networkx(cls, graph: networkx.Graph):
g = cls(graph)
return g

@classmethod
def from_json(cls, json_file):
"""Load a graph from a JSON file in the NetworkX json_graph format.
Expand All @@ -31,7 +38,7 @@ def from_json(cls, json_file):
with open(json_file) as f:
data = json.load(f)
g = json_graph.adjacency_graph(data)
graph = cls(g)
graph = cls.from_networkx(g)
graph.issue_warnings()
return graph

Expand Down Expand Up @@ -165,6 +172,22 @@ def from_geodataframe(
graph.add_data(df, columns=cols_to_add)
return graph

def lookup(self, node, field):
"""
Lookup a node/field attribute.
:param node: Node to look up.
:param field: Field to look up.
"""
return self.nodes[node][field]

@property
def node_indices(self):
return set(self.nodes)

@property
def edge_indices(self):
return set(self.edges)

def add_data(self, df, columns=None):
"""Add columns of a DataFrame to a graph as node attributes using
by matching the DataFrame's index to node ids.
Expand Down Expand Up @@ -310,3 +333,61 @@ def convert_geometries_to_geojson(data):
# This is what :func:`geopandas.GeoSeries.to_json` uses under
# the hood.
node[key] = node[key].__geo_interface__


class FrozenGraph:
""" Represents an immutable graph to be partitioned. It is based off :class:`Graph`.
This speeds up chain runs and prevents having to deal with cache invalidation issues.
This class behaves slightly differently than :class:`Graph` or :class:`networkx.Graph`.
"""
__slots__ = [
"graph",
"size"
]

def __init__(self, graph: Graph):
self.graph = networkx.classes.function.freeze(graph)
self.graph.join = frozen
self.graph.add_data = frozen
self.graph.add_data = frozen

self.size = len(self.graph)

def __len__(self):
return self.size

def __getattribute__(self, __name: str) -> Any:
try:
return object.__getattribute__(self, __name)
except AttributeError:
return object.__getattribute__(self.graph, __name)

def __getitem__(self, __name: str) -> Any:
return self.graph[__name]

def __iter__(self):
yield from self.node_indices

@functools.lru_cache(16384)
def neighbors(self, n):
return tuple(self.graph.neighbors(n))

@functools.cached_property
def node_indices(self):
return self.graph.node_indices

@functools.cached_property
def edge_indices(self):
return self.graph.edge_indices

@functools.lru_cache(16384)
def degree(self, n):
return self.graph.degree(n)

@functools.lru_cache(65536)
def lookup(self, node, field):
return self.graph.nodes[node][field]

def subgraph(self, nodes):
return FrozenGraph(self.graph.subgraph(nodes))
3 changes: 2 additions & 1 deletion gerrychain/grid.py
Expand Up @@ -3,6 +3,7 @@
import networkx

from gerrychain.partition import Partition
from gerrychain.graph import Graph
from gerrychain.updaters import (
Tally,
boundary_nodes,
Expand Down Expand Up @@ -63,7 +64,7 @@ def __init__(
"""
if dimensions:
self.dimensions = dimensions
graph = create_grid_graph(dimensions, with_diagonals)
graph = Graph.from_networkx(create_grid_graph(dimensions, with_diagonals))

if not assignment:
thresholds = tuple(math.floor(n / 2) for n in self.dimensions)
Expand Down
14 changes: 13 additions & 1 deletion gerrychain/partition/partition.py
@@ -1,5 +1,8 @@
import json
import geopandas
import networkx

from gerrychain.graph.graph import FrozenGraph, Graph
from ..updaters import compute_edge_flows, flows_from_changes, cut_edges
from .assignment import get_assignment
from .subgraphs import SubgraphView
Expand Down Expand Up @@ -50,7 +53,16 @@ def __init__(
self.subgraphs = SubgraphView(self.graph, self.parts)

def _first_time(self, graph, assignment, updaters, use_cut_edges):
self.graph = graph
if isinstance(graph, Graph):
self.graph = FrozenGraph(graph)
elif isinstance(graph, networkx.Graph):
graph = Graph.from_networkx(graph)
self.graph = FrozenGraph(graph)
elif isinstance(graph, FrozenGraph):
self.graph = graph
else:
raise TypeError("Unsupported Graph object")

self.assignment = get_assignment(assignment, graph)

if set(self.assignment) != set(graph):
Expand Down
2 changes: 1 addition & 1 deletion gerrychain/proposals/spectral_proposals.py
Expand Up @@ -15,7 +15,7 @@ def spectral_cut(graph, part_labels, weight_type, lap_type):
n = len(nlist)

if weight_type == "random":
for edge in graph.edges():
for edge in graph.edge_indices:
graph.edges[edge]["weight"] = random.random()

if lap_type == "normalized":
Expand Down
2 changes: 1 addition & 1 deletion gerrychain/proposals/tree_proposals.py
Expand Up @@ -46,7 +46,7 @@ def recom(
)

flips = recursive_tree_part(
subgraph,
subgraph.graph,
parts_to_merge,
pop_col=pop_col,
pop_target=pop_target,
Expand Down
23 changes: 12 additions & 11 deletions gerrychain/tree.py
Expand Up @@ -15,7 +15,7 @@ def successors(h, root):

def random_spanning_tree(graph):
""" Builds a spanning tree chosen by Kruskal's method using random weights.
:param graph: Networkx Graph
:param graph: FrozenGraph
Important Note:
The key is specifically labelled "random_weight" instead of the previously
Expand All @@ -24,7 +24,7 @@ def random_spanning_tree(graph):
This meant that the laplacian would change for the graph step to step,
something that we do not intend!!
"""
for edge in graph.edges:
for edge in graph.edge_indices:
graph.edges[edge]["random_weight"] = random.random()

spanning_tree = tree.maximum_spanning_tree(
Expand All @@ -39,14 +39,14 @@ def uniform_spanning_tree(graph, choice=random.choice):
:param graph: Networkx Graph
:param choice: :func:`random.choice`
"""
root = choice(list(graph.nodes))
root = choice(graph.node_indices)
tree_nodes = set([root])
next_node = {root: None}

for node in graph.nodes:
for node in graph.node_indices:
u = node
while u not in tree_nodes:
next_node[u] = choice(list(nx.neighbors(graph, u)))
next_node[u] = choice(graph.neighbors(u))
u = next_node[u]

u = node
Expand All @@ -65,12 +65,12 @@ def uniform_spanning_tree(graph, choice=random.choice):
class PopulatedGraph:
def __init__(self, graph, populations, ideal_pop, epsilon):
self.graph = graph
self.subsets = {node: {node} for node in graph}
self.subsets = {node: {node} for node in graph.node_indices}
self.population = populations.copy()
self.tot_pop = sum(self.population.values())
self.ideal_pop = ideal_pop
self.epsilon = epsilon
self._degrees = {node: graph.degree(node) for node in graph}
self._degrees = {node: graph.degree(node) for node in graph.node_indices}

def __iter__(self):
return iter(self.graph)
Expand Down Expand Up @@ -194,7 +194,7 @@ def bipartition_tree(
tree is not provided
:param choice: :func:`random.choice`. Can be substituted for testing.
"""
populations = {node: graph.nodes[node][pop_col] for node in graph}
populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices}

possible_cuts = []
if spanning_tree is None:
Expand Down Expand Up @@ -224,7 +224,7 @@ def _bipartition_tree_random_all(
choice=random.choice,
):
"""Randomly bipartitions a graph and returns all cuts."""
populations = {node: graph.nodes[node][pop_col] for node in graph}
populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices}

possible_cuts = []
if spanning_tree is None:
Expand Down Expand Up @@ -303,11 +303,12 @@ def recursive_tree_part(
:param epsilon: How far (as a percentage of ``pop_target``) from ``pop_target`` the parts
of the partition can be
:param node_repeats: Parameter for :func:`~gerrychain.tree_methods.bipartition_tree` to use.
:param method: The partition method to use.
:return: New assignments for the nodes of ``graph``.
:rtype: dict
"""
flips = {}
remaining_nodes = set(graph.nodes)
remaining_nodes = graph.node_indices
# We keep a running tally of deviation from ``epsilon`` at each partition
# and use it to tighten the population constraints on a per-partition
# basis such that every partition, including the last partition, has a
Expand Down Expand Up @@ -376,7 +377,7 @@ def get_seed_chunks(
new_epsilon = epsilon

chunk_pop = 0
for node in graph.nodes:
for node in graph.node_indices:
chunk_pop += graph.nodes[node][pop_col]

while True:
Expand Down
4 changes: 2 additions & 2 deletions gerrychain/updaters/county_splits.py
Expand Up @@ -35,8 +35,8 @@ def compute_county_splits(partition, county_field, partition_field):
if not partition.parent:
county_dict = dict()

for node in partition.graph:
county = partition.graph.nodes[node][county_field]
for node in partition.graph.node_indices:
county = partition.graph.lookup(node, county_field)
if county in county_dict:
split, nodes, seen = county_dict[county]
else:
Expand Down
16 changes: 13 additions & 3 deletions gerrychain/updaters/tally.py
Expand Up @@ -9,6 +9,11 @@ class DataTally:
"""An updater for tallying numerical data that is not necessarily stored as
node attributes
"""
__slots__ = [
"data",
"alias",
"_call"
]

def __init__(self, data, alias):
"""
Expand Down Expand Up @@ -54,6 +59,11 @@ def __call__(self, partition, previous=None):
class Tally:
"""An updater for keeping a tally of one or more node attributes.
"""
__slots__ = [
"fields",
"alias",
"dtype"
]

def __init__(self, fields, alias=None, dtype=int):
"""
Expand Down Expand Up @@ -116,12 +126,12 @@ def _update_tally(self, partition):
return new_tally

def _get_tally_from_node(self, partition, node):
return sum(partition.graph.nodes[node][field] for field in self.fields)
return sum(partition.graph.lookup(node, field) for field in self.fields)


def compute_out_flow(graph, fields, flow):
return sum(graph.nodes[node][field] for node in flow["out"] for field in fields)
return sum(graph.lookup(node, field) for node in flow["out"] for field in fields)


def compute_in_flow(graph, fields, flow):
return sum(graph.nodes[node][field] for node in flow["in"] for field in fields)
return sum(graph.lookup(node, field) for node in flow["in"] for field in fields)
1 change: 0 additions & 1 deletion setup.py
Expand Up @@ -31,7 +31,6 @@
install_requires=requirements,
keywords="GerryChain",
classifiers=[
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Expand Up @@ -56,7 +56,7 @@ def graph(three_by_three_grid):

@pytest.fixture
def example_partition():
graph = networkx.complete_graph(3)
graph = Graph.from_networkx(networkx.complete_graph(3))
assignment = {0: 1, 1: 1, 2: 2}
partition = Partition(graph, assignment, {"cut_edges": cut_edges})
return partition
Expand Down
5 changes: 3 additions & 2 deletions tests/constraints/test_validity.py
Expand Up @@ -11,14 +11,15 @@
single_flip_contiguous)
from gerrychain.partition import Partition
from gerrychain.partition.partition import get_assignment
from gerrychain.graph import Graph


@pytest.fixture
def contiguous_partition_with_flips():
graph = nx.Graph()
graph.add_nodes_from(range(4))
graph.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0)])
partition = Partition(graph, {0: 0, 1: 1, 2: 1, 3: 0})
partition = Partition(Graph.from_networkx(graph), {0: 0, 1: 1, 2: 1, 3: 0})

# This flip will maintain contiguity.
return partition, {0: 1}
Expand All @@ -29,7 +30,7 @@ def discontiguous_partition_with_flips():
graph = nx.Graph()
graph.add_nodes_from(range(4))
graph.add_edges_from([(0, 1), (1, 2), (2, 3)])
partition = Partition(graph, {0: 0, 1: 1, 2: 1, 3: 0})
partition = Partition(Graph.from_networkx(graph), {0: 0, 1: 1, 2: 1, 3: 0})

# This flip will maintain discontiguity.
return partition, {1: 0}
Expand Down

0 comments on commit 17fa8bf

Please sign in to comment.