Skip to content

Commit

Permalink
Switch to use a segment_sum rather than a loop.
Browse files Browse the repository at this point in the history
  • Loading branch information
sschoenholz committed Nov 3, 2021
1 parent 200b9bb commit 25c8165
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions jax_md/partition.py
Expand Up @@ -132,14 +132,11 @@ def count_cell_filling(R: Array,

particle_index = jnp.array(R / cell_size, dtype=jnp.int64)
particle_hash = jnp.sum(particle_index * hash_multipliers, axis=1)
filling = jnp.zeros((cell_count,), dtype=jnp.int64)

def count(cell_hash, filling):
count = jnp.sum(particle_hash == cell_hash)
filling = ops.index_update(filling, ops.index[cell_hash], count)
return filling

return lax.fori_loop(0, cell_count, count, filling)
filling = ops.segment_sum(jnp.ones_like(particle_hash),
particle_hash,
cell_count)
return filling


def _is_variable_compatible_with_positions(R: Array) -> bool:
Expand Down

0 comments on commit 25c8165

Please sign in to comment.