Feature Idea: flat transform #94

Open
scheidan opened this issue Aug 3, 2022 · 5 comments

Comments

@scheidan
Contributor

scheidan commented Aug 3, 2022

It would be useful if transform had an option to keep the result as a flat vector:

transform(t, x, keep_flat=true)

A good use case is converting MCMC samples into MCMCChains.Chains objects:

import MCMCChains
using TransformVariables

# t: a transformation with dimension(t) == 5, e.g. the one defined in a later comment
samp = rand(1000, 5)            # in practice these would come from an MCMC algorithm

# this gives an array of named tuples, which is great for defining the model but difficult to convert to a `Chains`
samp_trans1 = mapslices(s -> transform(t, s), samp, dims=2)
MCMCChains.Chains(samp_trans1)              # fails

# with the proposed (not yet existing) keyword argument we would get a plain array
samp_trans2 = mapslices(s -> transform(t, s, keep_flat=true), samp, dims=2)
MCMCChains.Chains(samp_trans2)              # that would work
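Until something like keep_flat exists, the same effect can be sketched by composing the transformation with a flattening step. The sketch below deliberately avoids TransformVariables: `toy_transform` and `toy_flat` are hypothetical stand-ins that only mimic the *shape* of the output of the `t` used in this thread (a scalar, a 2-vector in (0, 1), and a 3-element unit vector, all from 5 unconstrained numbers). Their maps are not the ones TransformVariables uses.

```julia
using LinearAlgebra: norm

# Hypothetical stand-in for transform(t, x); NOT TransformVariables' actual maps.
# Takes 5 unconstrained numbers and returns a named tuple shaped like t's output.
function toy_transform(x::AbstractVector{<:Real})
    a = x[1]                                # real scalar
    b = 1 ./ (1 .+ exp.(-x[2:3]))           # two values in (0, 1) via the logistic map
    v = [x[4], x[5], 1.0]
    c = v ./ norm(v)                        # 3-element unit vector from 2 inputs
    return (a = a, b = b, c = c)
end

# Flatten the named tuple into a plain vector (a: 1 value, b: 2 values, c: 3 values).
toy_flat(nt::NamedTuple) = vcat(nt.a, nt.b, nt.c)

samp = randn(1000, 5)  # placeholder for MCMC draws, one draw per row

# Each row of 5 unconstrained values becomes a row of 6 transformed values,
# so the result is an iterations × flattened-parameters matrix.
samp_flat = mapslices(s -> toy_flat(toy_transform(s)), samp, dims=2)
size(samp_flat)  # (1000, 6)
```

Note that the flattened width (6) differs from the number of unconstrained inputs (5), because the unit vector has one more component than it has free parameters.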

This seems related to #13.

@tpapp
Owner

tpapp commented Aug 4, 2022

What format would you expect for samp_trans2 here? I am not familiar with MCMCChains.Chains.

@scheidan
Contributor Author

scheidan commented Aug 4, 2022

MCMCChains.Chains expects an Array of dimensions iterations × n_parameters (or iterations × n_parameters × n_chains).

Having a flat transform would make constructing such an Array quite easy. We would need to be careful with the length, though, since the flattened vector can be longer than dimension(t):

using TransformVariables

t = as((a = asℝ,
        b = as(Vector, as(Real, 0, 1), 2),
        c = UnitVector(3)))

x = randn(dimension(t))  # length(x) == 5
transform(t, x)  # -> named tuple
transform(t, x, keep_flat=true)  # proposed: -> vector of length 6 != dimension(t)
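For several chains, the per-chain iterations × n_parameters matrices can be stacked along a third dimension to get the iterations × n_parameters × n_chains layout mentioned above. A minimal sketch with placeholder random data; the 6 columns stand for the flattened length of the example, not dimension(t), and the chain count of 4 is arbitrary.

```julia
n_iter, n_flat, n_chains = 1000, 6, 4

# Placeholder per-chain matrices of flattened draws (iterations × flattened parameters).
chains = [randn(n_iter, n_flat) for _ in 1:n_chains]

# Stack along a new third dimension: iterations × parameters × chains.
arr = cat(chains...; dims = 3)
size(arr)  # (1000, 6, 4)
```

An array in this layout is what the comment above says MCMCChains.Chains expects.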

@tpapp
Owner

tpapp commented Aug 6, 2022

Thanks, I get it. It should be relatively easy to flatten transformed values:

flatten(x::Real) = [x]
flatten(x::AbstractArray) = vec(x)
flatten(x::Tuple) = mapreduce(flatten, vcat, x)
flatten(x::NamedTuple) = mapreduce(flatten, vcat, values(x))

z = (a = 1.0, b = [2.0, 3.0], c = (d = 4.0, e = 5.0))

flatten(z)

This can deal with everything TransformVariables can dish out at the moment. (The code above necessarily allocates and is quite suboptimal; ideally this would be done with views, as in https://github.com/JuliaArrays/StackViews.jl.)
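Short of using views, one way to cut down the allocations is to write the flattened values into a buffer that is allocated once and reused for every draw. A minimal sketch; `flatten!` and `flatlength` are hypothetical helpers, not part of TransformVariables, and they cover the same cases as the `flatten` methods above (reals, arrays, tuples, named tuples) in the same order.

```julia
# Write the leaves of x into `out` starting at index `i`; return the next free index.
# Hypothetical helper, not part of TransformVariables.
flatten!(out::AbstractVector, x::Real, i::Int) = (out[i] = x; i + 1)

function flatten!(out::AbstractVector, x::AbstractArray, i::Int)
    for v in x                      # column-major order, same as vec(x)
        out[i] = v
        i += 1
    end
    return i
end

function flatten!(out::AbstractVector, x::Union{Tuple,NamedTuple}, i::Int)
    for v in values(x)
        i = flatten!(out, v, i)
    end
    return i
end

# Number of scalar leaves, used to size the buffer once.
flatlength(x::Real) = 1
flatlength(x::AbstractArray) = length(x)
flatlength(x::Union{Tuple,NamedTuple}) = sum(flatlength, values(x); init = 0)

z = (a = 1.0, b = [2.0, 3.0], c = (d = 4.0, e = 5.0))
buf = Vector{Float64}(undef, flatlength(z))   # allocated once, reusable across draws
flatten!(buf, z, 1)
buf  # [1.0, 2.0, 3.0, 4.0, 5.0]
```

The returned index makes the recursion compose without intermediate vectors; iterating over a heterogeneous tuple is not type-stable, so this is a starting point rather than an optimized version.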

Or would you prefer transforming directly to a flat vector for efficiency? I will keep this in mind for the next refactoring (which is coming up soon).

@tpapp
Owner

tpapp commented Aug 6, 2022

Also, an ideal API would give column names, such as [:a, :b_1, :b_2, :c_d, :c_e] or similar.
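Names in that style can be generated by walking the value in the same order as the flattening. A minimal sketch; `flatnames` is a hypothetical helper (not a TransformVariables API) that handles reals, arrays, and named tuples, joins nested keys with `_`, and appends 1-based linear indices for array entries.

```julia
# Hypothetical helper (not part of TransformVariables): one name per scalar leaf,
# joining nested keys with "_" and appending 1-based linear indices for arrays.
flatnames(x::Real, prefix::String) = [Symbol(prefix)]
flatnames(x::AbstractArray, prefix::String) =
    [Symbol(prefix, "_", k) for k in 1:length(x)]

function flatnames(x::NamedTuple, prefix::String = "")
    names = Symbol[]
    for (k, v) in pairs(x)
        p = isempty(prefix) ? string(k) : string(prefix, "_", k)
        append!(names, flatnames(v, p))
    end
    return names
end

z = (a = 1.0, b = [2.0, 3.0], c = (d = 4.0, e = 5.0))
flatnames(z)  # [:a, :b_1, :b_2, :c_d, :c_e]
```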

@scheidan
Contributor Author

scheidan commented Aug 8, 2022

Getting meaningful names would be very helpful!

MCMCChains.jl has some support for bracketed names for variables that come from arrays, e.g. "x[1,1]", "x[1,2]"; see
https://beta.turing.ml/MCMCChains.jl/stable/getting-started/#Groups-of-parameters
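Bracket-style names can be produced in the same column-major order as vec(x), which also covers matrix-valued entries. A minimal sketch; `bracketnames` is a hypothetical helper, and joining nested keys with "." is an assumption (only the bracket form for arrays is mentioned above).

```julia
# Hypothetical helper: bracket-style names such as "x[1,2]" for array entries,
# in the same column-major order as vec(x). Nested keys are joined with "."
# (an assumption; only the bracket form for arrays is mentioned above).
bracketnames(x::Real, prefix::String) = [prefix]
bracketnames(x::AbstractArray, prefix::String) =
    [string(prefix, "[", join(Tuple(I), ","), "]") for I in vec(CartesianIndices(x))]

function bracketnames(x::NamedTuple, prefix::String = "")
    names = String[]
    for (k, v) in pairs(x)
        p = isempty(prefix) ? string(k) : string(prefix, ".", k)
        append!(names, bracketnames(v, p))
    end
    return names
end

nt = (x = [1.0 2.0; 3.0 4.0], s = 0.5)
bracketnames(nt)  # ["x[1,1]", "x[2,1]", "x[1,2]", "x[2,2]", "s"]
```

How such names are passed to MCMCChains.Chains should be checked against the MCMCChains documentation linked above.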

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants