From 56f0c351d10d8c75c1972bb73e34f2b6c0195c2b Mon Sep 17 00:00:00 2001 From: Amil Merchant Date: Wed, 21 Feb 2024 00:40:31 +0000 Subject: [PATCH 1/2] Make functions within partition visible for downstream usecases. --- jax_md/partition.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/jax_md/partition.py b/jax_md/partition.py index ac27cf70..0760d8f1 100644 --- a/jax_md/partition.py +++ b/jax_md/partition.py @@ -206,7 +206,11 @@ def _estimate_cell_capacity(position: Array, return int(cell_capacity * buffer_size_multiplier) +<<<<<<< Updated upstream def _shift_array(arr: Array, dindex: Array) -> Array: +======= +def shift_array(arr: Array, dindex: Array) -> Array: +>>>>>>> Stashed changes if len(dindex) == 2: dx, dy = dindex dz = 0 @@ -231,7 +235,11 @@ def _shift_array(arr: Array, dindex: Array) -> Array: return arr +<<<<<<< Updated upstream def _unflatten_cell_buffer(arr: Array, +======= +def unflatten_cell_buffer(arr: Array, +>>>>>>> Stashed changes cells_per_side: Array, dim: int) -> Array: if (isinstance(cells_per_side, int) or @@ -380,14 +388,23 @@ def cell_list_fn(position: Array, cell_position = cell_position.at[sorted_cell_id].set(sorted_position) sorted_id = jnp.reshape(sorted_id, (N, 1)) cell_id = cell_id.at[sorted_cell_id].set(sorted_id) +<<<<<<< Updated upstream cell_position = _unflatten_cell_buffer(cell_position, cells_per_side, dim) cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim) +======= + cell_position = unflatten_cell_buffer(cell_position, cells_per_side, dim) + cell_id = unflatten_cell_buffer(cell_id, cells_per_side, dim) +>>>>>>> Stashed changes for k, v in sorted_kwargs.items(): if v.ndim == 1: v = jnp.reshape(v, v.shape + (1,)) cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v) +<<<<<<< Updated upstream cell_kwargs[k] = _unflatten_cell_buffer( +======= + cell_kwargs[k] = unflatten_cell_buffer( +>>>>>>> Stashed changes cell_kwargs[k], cells_per_side, dim) occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count) @@ -806,7 +823,7 @@ def cell_list_candidate_fn(cl_id_buffer, positionShape) -> Array: for dindex in _neighboring_cells(dim): if onp.all(dindex == 0): continue - cell_idx += [_shift_array(idx, dindex)] + cell_idx += [shift_array(idx, dindex)] cell_idx = jnp.concatenate(cell_idx, axis=-2) cell_idx = cell_idx[..., jnp.newaxis, :, :] From 5742e235677adc7fe0b0c1ad0fba426e97b119c6 Mon Sep 17 00:00:00 2001 From: Amil Merchant Date: Wed, 21 Feb 2024 00:53:43 +0000 Subject: [PATCH 2/2] Fix stash comments. --- jax_md/partition.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/jax_md/partition.py b/jax_md/partition.py index 0760d8f1..f1aca0e3 100644 --- a/jax_md/partition.py +++ b/jax_md/partition.py @@ -206,11 +206,7 @@ def _estimate_cell_capacity(position: Array, return int(cell_capacity * buffer_size_multiplier) -<<<<<<< Updated upstream -def _shift_array(arr: Array, dindex: Array) -> Array: -======= def shift_array(arr: Array, dindex: Array) -> Array: ->>>>>>> Stashed changes if len(dindex) == 2: dx, dy = dindex dz = 0 @@ -235,11 +231,7 @@ def shift_array(arr: Array, dindex: Array) -> Array: return arr -<<<<<<< Updated upstream -def _unflatten_cell_buffer(arr: Array, -======= def unflatten_cell_buffer(arr: Array, ->>>>>>> Stashed changes cells_per_side: Array, dim: int) -> Array: if (isinstance(cells_per_side, int) or @@ -388,23 +380,14 @@ def cell_list_fn(position: Array, cell_position = cell_position.at[sorted_cell_id].set(sorted_position) sorted_id = jnp.reshape(sorted_id, (N, 1)) cell_id = cell_id.at[sorted_cell_id].set(sorted_id) -<<<<<<< Updated upstream - cell_position = _unflatten_cell_buffer(cell_position, cells_per_side, dim) - cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim) -======= cell_position = unflatten_cell_buffer(cell_position, cells_per_side, dim) cell_id = unflatten_cell_buffer(cell_id, cells_per_side, dim) ->>>>>>> Stashed changes for k, v in sorted_kwargs.items(): if v.ndim == 1: v = jnp.reshape(v, v.shape + (1,)) cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v) -<<<<<<< Updated upstream - cell_kwargs[k] = _unflatten_cell_buffer( -======= cell_kwargs[k] = unflatten_cell_buffer( ->>>>>>> Stashed changes cell_kwargs[k], cells_per_side, dim) occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)