Skip to content

Commit

Permalink
Make functions within partition visible for downstream usecases.
Browse files Browse the repository at this point in the history
  • Loading branch information
amilmerchant committed Feb 21, 2024
1 parent 31dd113 commit 56f0c35
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion jax_md/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,11 @@ def _estimate_cell_capacity(position: Array,
return int(cell_capacity * buffer_size_multiplier)


<<<<<<< Updated upstream
def _shift_array(arr: Array, dindex: Array) -> Array:
=======
def shift_array(arr: Array, dindex: Array) -> Array:
>>>>>>> Stashed changes
if len(dindex) == 2:
dx, dy = dindex
dz = 0
Expand All @@ -231,7 +235,11 @@ def _shift_array(arr: Array, dindex: Array) -> Array:
return arr


<<<<<<< Updated upstream
def _unflatten_cell_buffer(arr: Array,
=======
def unflatten_cell_buffer(arr: Array,
>>>>>>> Stashed changes
cells_per_side: Array,
dim: int) -> Array:
if (isinstance(cells_per_side, int) or
Expand Down Expand Up @@ -380,14 +388,23 @@ def cell_list_fn(position: Array,
cell_position = cell_position.at[sorted_cell_id].set(sorted_position)
sorted_id = jnp.reshape(sorted_id, (N, 1))
cell_id = cell_id.at[sorted_cell_id].set(sorted_id)
<<<<<<< Updated upstream
cell_position = _unflatten_cell_buffer(cell_position, cells_per_side, dim)
cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)
=======
cell_position = unflatten_cell_buffer(cell_position, cells_per_side, dim)
cell_id = unflatten_cell_buffer(cell_id, cells_per_side, dim)
>>>>>>> Stashed changes

for k, v in sorted_kwargs.items():
if v.ndim == 1:
v = jnp.reshape(v, v.shape + (1,))
cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v)
<<<<<<< Updated upstream
cell_kwargs[k] = _unflatten_cell_buffer(
=======
cell_kwargs[k] = unflatten_cell_buffer(
>>>>>>> Stashed changes
cell_kwargs[k], cells_per_side, dim)

occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)
Expand Down Expand Up @@ -806,7 +823,7 @@ def cell_list_candidate_fn(cl_id_buffer, positionShape) -> Array:
for dindex in _neighboring_cells(dim):
if onp.all(dindex == 0):
continue
cell_idx += [_shift_array(idx, dindex)]
cell_idx += [shift_array(idx, dindex)]

cell_idx = jnp.concatenate(cell_idx, axis=-2)
cell_idx = cell_idx[..., jnp.newaxis, :, :]
Expand Down

0 comments on commit 56f0c35

Please sign in to comment.