_sample_in_masked_pareto_front does not return random_key #50

Open
Lookatator opened this issue Jul 5, 2022 · 0 comments
Labels: enhancement (New feature or request)

@Lookatator (Member):

The function `_sample_in_masked_pareto_front` does not return an updated `random_key`, unlike all the other functions in the project.

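```python
from functools import partial

import jax
import jax.numpy as jnp

# Method excerpt from the repository's class. Genotype and RNGKey are the
# project's type aliases and are not defined in this excerpt.
@partial(jax.jit, static_argnames=("num_samples",))
def _sample_in_masked_pareto_front(
    self,
    pareto_front_genotypes: Genotype,
    mask: jnp.ndarray,
    random_key: RNGKey,
) -> Genotype:
    """Sample num_samples elements in masked pareto front.

    Note: do not retrieve a random key because this function
    is to be vmapped. The public method that uses this function
    will return a random key.
    """
    ...  # body omitted in this excerpt
```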
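The design described in the docstring (a private per-front sampler that consumes its key, and a public wrapper that splits the key and hands a fresh one back) can be sketched as follows. This is a minimal sketch, not QDax code: `_sample_one` and `sample_in_fronts` are hypothetical names, each front is assumed to be a 1-D array of scalars, and `mask` is assumed to be boolean with `True` marking excluded entries.

```python
import jax
import jax.numpy as jnp


def _sample_one(front, mask, random_key):
    """Hypothetical per-front sampler: consumes its key and returns no key."""
    # Uniform probability over the unmasked entries (True = excluded, assumed).
    p = jnp.where(mask, 0.0, 1.0)
    p = p / jnp.sum(p)
    return jax.random.choice(random_key, front, p=p)


def sample_in_fronts(fronts, masks, random_key):
    """Hypothetical public wrapper: splits the key and returns a fresh one."""
    num_fronts = fronts.shape[0]
    # One subkey per front, plus one key handed back to the caller.
    keys = jax.random.split(random_key, num_fronts + 1)
    random_key, subkeys = keys[0], keys[1:]
    # Every argument, including the per-front subkeys, is batched on axis 0.
    samples = jax.vmap(_sample_one)(fronts, masks, subkeys)
    return samples, random_key


fronts = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
masks = jnp.array([[False, True, True], [True, False, True]])
samples, new_key = sample_in_fronts(fronts, masks, jax.random.PRNGKey(0))
```

Here only the wrapper follows the project's "return an updated key" convention; the vmapped helper itself still breaks it, which is what this issue is about.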

This can probably be solved by using the `in_axes` and `out_axes` arguments of `jax.vmap`.
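One way the `in_axes`/`out_axes` idea could work is sketched below, under the same toy assumptions (1-D fronts of scalars, boolean mask with `True` meaning excluded); `sample_in_masked_front` is a hypothetical stand-in, not the QDax method. The key is broadcast with `in_axes=None`, so the key returned by the function is unbatched and can be returned once with `out_axes=None`. Because a broadcast key would otherwise give every front the same subkey, each call folds in its front index with `jax.random.fold_in`.

```python
import jax
import jax.numpy as jnp


def sample_in_masked_front(front, mask, index, random_key):
    """Hypothetical sampler that also returns an updated random key."""
    # Split once: `subkey` is consumed here, `random_key` is handed back.
    random_key, subkey = jax.random.split(random_key)
    # Make the subkey differ per front, since the key itself is broadcast.
    subkey = jax.random.fold_in(subkey, index)
    # Uniform probability over the unmasked entries (True = excluded, assumed).
    p = jnp.where(mask, 0.0, 1.0)
    p = p / jnp.sum(p)
    sample = jax.random.choice(subkey, front, p=p)
    return sample, random_key


batched_sample = jax.vmap(
    sample_in_masked_front,
    in_axes=(0, 0, 0, None),  # batch fronts, masks and indices; broadcast the key
    out_axes=(0, None),  # batched samples, one unbatched key
)

fronts = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
masks = jnp.array([[False, True, True], [True, False, True]])
indices = jnp.arange(fronts.shape[0])
samples, new_key = batched_sample(fronts, masks, indices, jax.random.PRNGKey(0))
```

`out_axes=None` is only accepted when that output does not depend on any batched input; here the returned key depends only on the broadcast key, so `vmap` can return it unbatched.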

Lookatator added the enhancement (New feature or request) label on Jul 5, 2022
Lookatator self-assigned this on Jul 5, 2022
Lookatator added this to the QDax v0.1.0 milestone on Jul 5, 2022
Lookatator removed this from the QDax v0.1.0 milestone on Jul 22, 2022
Development: no branches or pull requests · 1 participant