Skip to content

Commit

Permalink
Merge pull request #300 from amilmerchant/main
Browse files Browse the repository at this point in the history
Make functions within partition.py public-access for downstream usecases.
  • Loading branch information
ekindogus committed Feb 21, 2024
2 parents 4ee99f6 + 5742e23 commit e08fb81
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions jax_md/partition.py
Expand Up @@ -206,7 +206,7 @@ def _estimate_cell_capacity(position: Array,
return int(cell_capacity * buffer_size_multiplier)


def _shift_array(arr: Array, dindex: Array) -> Array:
def shift_array(arr: Array, dindex: Array) -> Array:
if len(dindex) == 2:
dx, dy = dindex
dz = 0
Expand All @@ -231,7 +231,7 @@ def _shift_array(arr: Array, dindex: Array) -> Array:
return arr


def _unflatten_cell_buffer(arr: Array,
def unflatten_cell_buffer(arr: Array,
cells_per_side: Array,
dim: int) -> Array:
if (isinstance(cells_per_side, int) or
Expand Down Expand Up @@ -380,14 +380,14 @@ def cell_list_fn(position: Array,
cell_position = cell_position.at[sorted_cell_id].set(sorted_position)
sorted_id = jnp.reshape(sorted_id, (N, 1))
cell_id = cell_id.at[sorted_cell_id].set(sorted_id)
cell_position = _unflatten_cell_buffer(cell_position, cells_per_side, dim)
cell_id = _unflatten_cell_buffer(cell_id, cells_per_side, dim)
cell_position = unflatten_cell_buffer(cell_position, cells_per_side, dim)
cell_id = unflatten_cell_buffer(cell_id, cells_per_side, dim)

for k, v in sorted_kwargs.items():
if v.ndim == 1:
v = jnp.reshape(v, v.shape + (1,))
cell_kwargs[k] = cell_kwargs[k].at[sorted_cell_id].set(v)
cell_kwargs[k] = _unflatten_cell_buffer(
cell_kwargs[k] = unflatten_cell_buffer(
cell_kwargs[k], cells_per_side, dim)

occupancy = ops.segment_sum(jnp.ones_like(hashes), hashes, cell_count)
Expand Down Expand Up @@ -806,7 +806,7 @@ def cell_list_candidate_fn(cl_id_buffer, positionShape) -> Array:
for dindex in _neighboring_cells(dim):
if onp.all(dindex == 0):
continue
cell_idx += [_shift_array(idx, dindex)]
cell_idx += [shift_array(idx, dindex)]

cell_idx = jnp.concatenate(cell_idx, axis=-2)
cell_idx = cell_idx[..., jnp.newaxis, :, :]
Expand Down

0 comments on commit e08fb81

Please sign in to comment.