Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for user-supplied RNG state in all interfaces #520

Open
wants to merge 9 commits into
base: modular-rng
Choose a base branch
from

Conversation

bgroenks96
Copy link

@bgroenks96 bgroenks96 commented Feb 27, 2024

Updates the Gen function interface to use the standard Julia pattern for user-supplied RNG states, i.e:

myfunc(args...) = myfunc(default_rng(), args...)
myfunc(rng::AbstractRNG, args...) = ...

This is applied to all function interfaces which use rng.

Inference algorithms provide instead a keyword argument rng which tends to be more common for higher level function interfaces.

Note that this PR should be fully backwards compatible with all tests and existing Gen code since method dispatches with default_rng() are universally provided.

Resolves #33

Updates the Gen function interface to use the standard Julia pattern for user-supplied RNG states, i.e:

myfunc(args...) = myfunc(default_rng(), args...)
myfunc(rng::AbstractRNG, args...) = ...

This is applied to all function interfaces which use rng.

Inference algorithms provide instead a keyword argument `rng` which tends to be more common for higher level function interfaces.
@fsaad
Copy link
Collaborator

fsaad commented Feb 27, 2024

Hi @bgroenks96, thank you for this thorough PR. I support merging these improvements to the API.

Could you please take a look at the failing ContinuousIntegration tests? There are appears to be a "Not implemented" error in one of the tests.
https://github.com/probcomp/Gen.jl/actions/runs/8067611177/job/22038710386?pr=520

@alex-lew @ztangent Please also take a look.

@bgroenks96
Copy link
Author

bgroenks96 commented Feb 27, 2024

@fsaad Fixed in the last two commits. All tests are passing for me locally.

@bgroenks96
Copy link
Author

Regarding the fix in e39459b, I am not sure that I fully understand this part of the API, but I am assuming by the name that deterministic functions should not need access to the RNG. If this is wrong, then we would unfortunately need to break this part of the API, I think.

@ztangent
Copy link
Member

Thanks for this PR! I'll review this in the next couple of days, but intuitively this seems like the right way to modify the interfaces and I've checked that other implementations of Gen (e.g. the work-in-progress JAX implementation) also adopt a similar interface for control over the RNG / RNG seed.

Copy link
Member

@ztangent ztangent left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for making this PR! I've left comments in most of the places where changes are needed. Here's a summary of the changes that still need to be made:

  • Modify GF...State types to become parametric types
  • For the dynamic DSl, make sure in traceat calls that the RNG is passed down to recursive calls to GFI methods like simulate, generate, etc.
  • Add a fallback implementation (with a warning) from the RNG-version of the GFI methods to the non-RNG version of the GFI methods. This is to prevent breaking existing code outside of Gen.jl that does not use the non-RNG version (same goes for the definition of random).
  • Update the GFI docstrings or documentation to mention that a custom RNG can be provided by the user.
  • For the static modeling language, ensure that the RNG is passed down to nested GFI calls for all GFI methods.
  • For the static modeling language, replace the rng variable name with a globally gen-symed variable name, to avoid name collisions.

In addition, it appears that these parts of the Gen.jl library need to be updated to make use of custom RNGs:

  • The rest of the combinators, like Map, Unfold, and CallAt,
  • Most of the inference library:
    • Importance (re)sampling
    • Particle filtering
    • Trace translators
    • Elliptical slice sampling
    • Trace kernel DSL, since it allows users to write code that randomly decides between MCMC kernels.

I understand that this is a fair amount of work, and that we should possibly break it up into separate PRs. One way to do this might be for us to create a separate branch of Gen dedicated to merging this broader set of changes, and then each PR can focus on supporting custom RNGs for various portions of the code base. If that sounds good to you, I can go ahead and create a branch called modular-rng (as you've called your branch).

@@ -47,10 +47,13 @@ accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad

mutable struct GFUntracedState
params::Dict{Symbol,Any}
rng::AbstractRNG
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should change all these GF...State structs to be parametric in the type of the RNG, to avoid potential performance regressions due to type instability. To be specific, I would replace this with:

mutable struct GFUntracedState{R <: AbstractRNG}
    params::Dict{Symbol, Any}
    rng::R
end

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -85,7 +88,7 @@ end
gen_fn(args...)

@inline traceat(state::GFUntracedState, dist::Distribution, args, key) =
random(dist, args...)
random(state.rng, dist, args...)

@inline splice(state::GFUntracedState, gen_fn::DynamicDSLFunction, args::Tuple) =
gen_fn(args...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

splice should also pass state.rng as the first argument to gen_fn.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -85,7 +88,7 @@ end
gen_fn(args...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should pass state.rng as the first argument to gen_fn.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -78,9 +79,12 @@ function splice(state::GFGenerateState, gen_fn::DynamicDSLFunction,
retval
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 59, the recursive call to generate needs to pass state.rng to the callee function.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -55,8 +56,8 @@ function splice(state::GFProposeState, gen_fn::DynamicDSLFunction, args::Tuple)
retval
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 40, state.rng needs to be passed to the recursive call to propose.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -59,29 +59,35 @@ function assess(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple, choices::ChoiceMap
(weight, value)
end

function propose(gen_fn::ChoiceAtCombinator{T,K}, args::Tuple) where {T,K}
propose(gen_fn::ChoiceAtCombinator, args::Tuple) = propose(default_rng(), gen_fn, args)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the CallAtCombinator in call_at.jl also needs to be updated to pass down rng to any nested generative function calls. Same goes for Map, Unfold etc.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -131,8 +131,9 @@ end
# TODO
accepts_output_grad(::Recurse) = false

function (gen_fn::Recurse)(args...)
(_, _, retval) = propose(gen_fn, args)
(gen_fn::Recurse)(args...) = gen_fn(default_rng(), args...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make the same changes for Map and Unfold (which are much more widely used combinators).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


Calls `random` with the default global RNG.
"""
random(dist::Distribution, args...) = random(default_rng(), dist, args...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For similar reasons to the comments I made on the definition of simulate, I believe we also need a fallback in the other direction, along with a warning. There's a fair amount of custom distributions that people have written with Gen (see e.g. https://github.com/probcomp/GenDistributions.jl), and we need to add a fallback from a version with the RNG to the version without the RNG.

Also, I don't think we need another docstring for this definition of random since this random is already documented above. The original docstring should just be modified to note that the user can supply a custom RNG.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -25,7 +25,7 @@ function process!(state::StaticIRSimulateState, node::RandomChoiceNode, options)
incr = gensym("logpdf")
addr = QuoteNode(node.addr)
dist = QuoteNode(node.dist)
push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :random))($dist, $(args...))))
push!(state.stmts, :($(node.name) = $(GlobalRef(Gen, :random))(rng, $dist, $(args...))))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition to this change, we need to make sure that recursive calls to simulate also pass along the RNG. e.g. on line 43 of this file, there is a recursive call to simulate that should be passed rng as the first argument.

This should be done for all of the GFI functions.

Also, I'll make this point again later below, but to be safe and avoid name collisions, I believe rng here should be replaced with a globally gen-symed variable name called STATIC_RNG. Otherwise, if the user happens to define their own variable called rng in their function definition, the generated code may end up being buggy.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -63,18 +63,18 @@ function generate_generative_function(ir::StaticIR, name::Symbol, options::Stati
$(GlobalRef(Gen, :get_options))(::Type{$gen_fn_type_name}) = $(QuoteNode(options))
# Generate GFI definitions
(gen_fn::$gen_fn_type_name)(args...) = $(GlobalRef(Gen, :propose))(gen_fn, args)[3]
@generated function $(GlobalRef(Gen, :simulate))(gen_fn::$gen_fn_type_name, args::$(QuoteNode(Tuple)))
@generated function $(GlobalRef(Gen, :simulate))(rng::$AbstractRNG, gen_fn::$gen_fn_type_name, args::$(QuoteNode(Tuple)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As noted above, rather than calling the variable rng, I believe it will be safer to add this global definition somewhere near the top of this file:

"Global reference to the RNG variable for the static modeling language."
const STATIC_RNG = gensym("rng")

And then change the above line to:

@generated function $(GlobalRef(Gen, :simulate))($STATIC_RNG::$AbstractRNG, gen_fn::$gen_fn_type_name, args::$(QuoteNode(Tuple)))

$STATIC_RNG should then be used whenever generating code that needs some reference to the RNG.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bgroenks96
Copy link
Author

Thanks for the detailed review, @ztangent . I was afraid this would turn out to be more complicated than it initially looked...

I will try to go through the comments later this week. In the meantime, it would be good to go ahead and create that branch, and then I can change the PR to target this instead of master.

@bgroenks96
Copy link
Author

bgroenks96 commented Apr 29, 2024

Hi @ztangent, I apologize for the long delay. I had some other more pressing deadlines to attend to.

I think that I have addressed your first set of comments, as well as the issues with the combinators. I still need to look more closely at the inference algorithms.

Please let me know if I have missed anything or if I did not fully address any of the issues.

EDIT: Note that I have verified that all tests are passing on my machine (as of 0539d93).

@ztangent ztangent changed the base branch from master to modular-rng May 1, 2024 01:34
@ztangent
Copy link
Member

ztangent commented May 1, 2024

Awesome, thank you! I should have time to look more at this the week after next.

I've also created the modular-rng branch as discussed before, so we can make the changes to the inference library as a separate PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Determine policy for controlling entropy
3 participants