Skip to content

Commit

Permalink
Fix Chris comments and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
peterrrock2 committed Mar 1, 2024
1 parent 3294ca9 commit 9dba10a
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 29 deletions.
Binary file modified docs/user/images/gerrymandria_region_ensamble.gif
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/user/images/gerrymandria_water_and_muni_aware.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/user/images/gerrymandria_water_muni_ensamble.gif
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
65 changes: 63 additions & 2 deletions docs/user/recom.rst
Expand Up @@ -236,7 +236,7 @@ also increase the length of our chain to make sure that we have time to mix prop
total_steps=10000
)
Then, we can run the chain and look at the last 20 assignments in the ensemble
Then, we can run the chain and look at the last 40 assignments in the ensemble

.. image:: ./images/gerrymandria_water_muni_ensamble.gif
:width: 400px
Expand Down Expand Up @@ -267,6 +267,67 @@ while also being sensitive to the municipalities
</div>


How the Region Aware Implementation Works
-----------------------------------------

When working with region-aware ReCom chains, it is worth knowing how the spanning tree
of the dual graph is being split. Weights are randomly assigned to the edges of the graph
and then the surcharges are applied to the edges in the graph that span different regions
specified by the ``region_surcharge`` dictionary. So if we have
``region_surcharge={"muni": 0.2, "water": 0.8}``, then the edges that span different
municipalities will be upweighted by 0.2 and the edges that span different water districts
will be upweighted by 0.8. We then draw a minimum spanning tree using Kruskal's algorithm,
which picks the edges interior to the region first before picking the edges that bridge
different regions.

This makes it very likely that each region is largely contained in a connected subtree
attached to a bridge node. Thus, when we make a cut, the regions attached to the
bridge node are more likely to be (mostly) preserved in the subtree on either side
of the cut.

In the implementation of :meth:`~gerrychain.tree.biparition_tree` we further bias this
choice by deterministically selecting bridge edges first. In the event that multiple
types of regions are specified, the surcharges are added together, and edges are selected
first by the number of types of regions that they span, and then by the surcharge added to
those weights. So, if we have a region surcharge dictionary of ``{"a": 1, "b": 4, "c": 2}``
then we we look for edges according to the order

- ("a", "b", "c")
- ("b", "c")
- ("a", "b")
- ("a", "c")
- ("b")
- ("c")
- ("a")
- random

where the tuples indicate that a desired cut edge bridges both types of region in
the tuple. In the event that this is not the desired behaviour, then the user can simply
alter the ``cut_choice`` function in the constraints to be different. So, if the user
would prefer the cut edge to be a random edge with no deference to bridge edges,
then they might use ``random.choice()`` in the following way:

.. code-block:: python
proposal = partial(
recom,
pop_col="TOTPOP",
pop_target=ideal_population,
epsilon=0.01,
node_repeats=1,
region_surcharge={
"muni": 2.0,
"water_dist": 2.0
},
method = partial(
bipartition_tree,
cut_choice = random.choice,
)
)
**Note**: When ``region_surcharge`` is not specified, ``bipartition_tree`` will behave as if
``cut_choice`` is set to ``random.choice``.


.. .. attention::

Expand All @@ -288,7 +349,7 @@ while also being sensitive to the municipalities
.. the surcharges are in the range :math:`[0,1]`, then the surcharges from the surcharge
.. dictionary are added to them. In the event that
.. many edges within the tree have a surcharge above 1, then it can sometimes
.. cause the biparitioning step to stall.
.. cause the bipartitioning step to stall.
What to do if the Chain Gets Stuck
Expand Down
51 changes: 27 additions & 24 deletions gerrychain/tree.py
Expand Up @@ -210,7 +210,7 @@ def __repr__(self) -> str:
Cut.__doc__ = "Represents a cut in a graph."
Cut.edge.__doc__ = "The edge where the cut is made. Defaults to None."
Cut.weight.__doc__ = "The weight assigned to the edge (if any). Defaults to None."
Cut.subset.__doc__ = "The subset of nodes on one side of the cut. Defaults to None."
Cut.subset.__doc__ = "The (frozen) subset of nodes on one side of the cut. Defaults to None."


def find_balanced_edge_cuts_contraction(
Expand Down Expand Up @@ -242,7 +242,7 @@ def find_balanced_edge_cuts_contraction(
Cut(
edge=e,
weight=h.graph.edges[e].get("random_weight", random.random()),
subset=h.subsets[leaf].copy()
subset=frozenset(h.subsets[leaf].copy())
)
)
# Contract the leaf:
Expand Down Expand Up @@ -351,7 +351,7 @@ def find_balanced_edge_cuts_memoization(
Cut(
edge=e,
weight=h.graph.edges[e].get("random_weight", wt),
subset=_part_nodes(node, succ)
subset=frozenset(_part_nodes(node, succ))
)
)
elif abs((total_pop - tree_pop) - h.ideal_pop) <= h.ideal_pop * h.epsilon:
Expand All @@ -361,7 +361,7 @@ def find_balanced_edge_cuts_memoization(
Cut(
edge=e,
weight=h.graph.edges[e].get("random_weight", wt),
subset=set(h.graph.nodes) - _part_nodes(node, succ),
subset=frozenset(set(h.graph.nodes) - _part_nodes(node, succ)),
)
)
return cuts
Expand Down Expand Up @@ -390,22 +390,25 @@ def _max_weight_choice(
cut_edge_list: List[Cut]
) -> Cut:
"""
Each Cut object in the list is assigned a random weight
either coming from the implementation of Kruskal's algorithm
Each Cut object in the list is assigned a random weight.
This random weight is either assigned during the call to
the minimum spanning tree algorithm (Kruskal's) algorithm
or it is generated during the selection of the balanced edges
(cf. :meth:`find_balanced_edge_cuts_memoization` and
:meth:`find_balanced_edge_cuts_contraction`).
This function returns the cut with the highest weight.
In the case of a situation where a region aware chain is run,
this will preferentially select for cuts that are between
regions, rather than within them (the likelihood of this
In the case where a region aware chain is run, this will
preferentially select for cuts that span different regions, rather
than cuts that are interior to that region (the likelihood of this
is generally controlled by the ``region_surcharge`` parameter).
In all other cases, this is effectively the same as calling
random.choice() on the list of cuts since all of the weights
In any case where the surcharges are either not set or zero,
this is effectively the same as calling random.choice() on the
list of cuts. Under the above conditions, all of the weights
on the cuts are randomly generated on the interval [0,1], and
there is no mechanism in place weight any cut edge over another.
there is no outside force that might make the weight assigned
to a particular type of cut higher than another.
:param cut_edge_list: A list of Cut objects. Each object has an
edge, a weight, and a subset attribute.
Expand Down Expand Up @@ -452,14 +455,14 @@ def _region_preferred_max_weight_choice(
"""
This function is used in the case of a region-aware chain. It
is similar to the as :meth:`_max_weight_choice` function except
that it will preferentially select one of the cuts that
has the highest surcharge preferentially. So, if we have a
weight dict of the form ``{region1: wt1, region2: wt2}`` , then
this function first looks for a cut that is a cut edge for both
``region1`` and ``region2`` and then selects the one with the
highest weight. If no such cut exists, then it will then look for
a cut that is a cut edge for the region with the highest surcharge
(presumably the region that we care more about not splitting).
that it will preferentially select one of the cuts that has the
highest surcharge. So, if we have a weight dict of the form
``{region1: wt1, region2: wt2}`` , then this function first looks
for a cut that is a cut edge for both ``region1`` and ``region2``
and then selects the one with the highest weight. If no such cut
exists, then it will then look for a cut that is a cut edge for the
region with the highest surcharge (presumably the region that we care
more about not splitting).
In the case of 3 regions, it will first look for a cut that is a
cut edge for all 3 regions, then for a cut that is a cut edge for
Expand Down Expand Up @@ -537,7 +540,7 @@ def bipartition_tree(
max_attempts: Optional[int] = 100000,
warn_attempts: int = 1000,
allow_pair_reselection: bool = False,
cut_choice: Callable = random.choice
cut_choice: Callable = _region_preferred_max_weight_choice
) -> Set:
"""
This function finds a balanced 2 partition of a graph by drawing a
Expand Down Expand Up @@ -572,8 +575,8 @@ def bipartition_tree(
:param balance_edge_fn: The function to find balanced edge cuts. Defaults to
:func:`find_balanced_edge_cuts_memoization`.
:type balance_edge_fn: Callable, optional
:param choice: The function to make a random choice. Passed to ``balance_edge_fn``.
Can be substituted for testing.
:param choice: The function to make a random choice of root node for the population
tree. Passed to ``balance_edge_fn``. Can be substituted for testing.
Defaults to :func:`random.random()`.
:type choice: Callable, optional
:param max_attempts: The maximum number of attempts that should be made to bipartition.
Expand All @@ -586,7 +589,7 @@ def bipartition_tree(
function to ask it to reselect the pair of nodes to try and recombine. Defaults to False.
:type allow_pair_reselection: bool, optional
:param cut_choice: The function used to select the cut edge from the list of possible
balanced cuts. Defaults to :meth:`_max_weight_choice` .
balanced cuts. Defaults to :meth:`_region_preferred_max_weight_choice` .
:type cut_choice: Callable, optional
:returns: A subset of nodes of ``graph`` (whose induced subgraph is connected). The other
Expand Down
2 changes: 1 addition & 1 deletion tests/test_reproducibility.py
Expand Up @@ -112,5 +112,5 @@ def test_pa_freeze():

# This needs to be changed every time we change the
# tests around
assert hashlib.sha256(result.encode()).hexdigest() == "2e0d148c22f4d2f7f9ae39c8950892f4574ea65d197657c14446c359c5ccfd8e"
assert hashlib.sha256(result.encode()).hexdigest() == "7f355cd0f7c235f4d285db1c7593ba0d4a5558c404b70521c9837125df418384"

4 changes: 2 additions & 2 deletions tests/test_tree.py
Expand Up @@ -53,7 +53,7 @@ def twelve_by_twelve_with_pop():
def test_bipartition_tree_returns_a_subset_of_nodes(graph_with_pop):
ideal_pop = sum(graph_with_pop.nodes[node]["pop"] for node in graph_with_pop) / 2
result = bipartition_tree(graph_with_pop, "pop", ideal_pop, 0.25, 10)
assert isinstance(result, set)
assert isinstance(result, frozenset)
assert all(node in graph_with_pop.nodes for node in result)


Expand Down Expand Up @@ -240,7 +240,7 @@ def test_prime_bound():
def test_bipartition_tree_random_returns_a_subset_of_nodes(graph_with_pop):
ideal_pop = sum(graph_with_pop.nodes[node]["pop"] for node in graph_with_pop) / 2
result = bipartition_tree_random(graph_with_pop, "pop", ideal_pop, 0.25, 10)
assert isinstance(result, set)
assert isinstance(result, frozenset)
assert all(node in graph_with_pop.nodes for node in result)


Expand Down

0 comments on commit 9dba10a

Please sign in to comment.