diff --git a/jax_md/partition.py b/jax_md/partition.py index ac27cf70..f1aca0e3 100644 --- a/jax_md/partition.py +++ b/jax_md/partition.py @@ -206,7 +206,7 @@ def _estimate_cell_capacity(position: Array, return int(cell_capacity * buffer_size_multiplier) -def _shift_array(arr: Array, dindex: Array) -> Array: +def shift_array(arr: Array, dindex: Array) -> Array: if len(dindex) == 2: dx, dy = dindex dz = 0 @@ -231,7 +231,7 @@ def _shift_array(arr: Array, dindex: Array) -> Array: return arr -def _unflatten_cell_buffer(arr: Array, +def unflatten_cell_buffer(arr: Array, cells_per_side: Array, dim: int) -> Array: if (isinstance(cells_per_side, int) or @@ -380,14 +380,14 @@ def cell_list_fn(position: Array, cell_position = cell_position.at[sorted_cell_id].set(sorted_position) sorted_id = jnp.reshape(sorted_id, (N, 1)) cell_id = cell_id.at[sorted_cell_id].set(sorted_id) - cell_position = _unflatten_cell_buffer(cell_position, cells_per_side, dim) - cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim) + cell_position = unflatten_cell_buffer(cell_position, cells_per_side, dim) + cell_id = unflatten_cell_buffer(cell_id, cells_per_side, dim) for k, v in sorted_kwargs.items(): if v.ndim == 1: v = jnp.reshape(v, v.shape + (1,)) cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v) - cell_kwargs[k] = _unflatten_cell_buffer( + cell_kwargs[k] = unflatten_cell_buffer( cell_kwargs[k], cells_per_side, dim) occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count) @@ -806,7 +806,7 @@ def cell_list_candidate_fn(cl_id_buffer, positionShape) -> Array: for dindex in _neighboring_cells(dim): if onp.all(dindex == 0): continue - cell_idx += [_shift_array(idx, dindex)] + cell_idx += [shift_array(idx, dindex)] cell_idx = jnp.concatenate(cell_idx, axis=-2) cell_idx = cell_idx[..., jnp.newaxis, :, :]