Stochastic FW: optional batched user-provided functions #40

Open
matbesancon opened this issue Jan 18, 2021 · 4 comments

Comments

@matbesancon
Member

With the current stochastic Frank-Wolfe (SFW) interface, users provide functions that process a single data point; batching happens one level higher, where we call the provided functions.

One possibility would be to make users provide batched functions by default:

f_batched(θ, xs) = sum(f(θ, x_i) for x_i in xs)
g_batched(θ, xs) = sum(g(θ, x_i) for x_i in xs)

What users provide now is the equivalent of the per-point functions f and g above.
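To make the two interfaces concrete, here is a minimal sketch assuming a least-squares loss, where each data point is a hypothetical tuple `(a, b)` of a feature vector and a target. The per-point `f` and `g` play the role of what users provide today; `f_batched` and `g_batched` are built exactly as in the definitions above. The loss and data are illustrative assumptions, not part of the package.

```julia
using LinearAlgebra

# Hypothetical per-point least-squares loss and gradient;
# a data point x is a tuple (a, b): feature vector a, target b.
f(θ, x) = 0.5 * (dot(x[1], θ) - x[2])^2
g(θ, x) = (dot(x[1], θ) - x[2]) * x[1]

# Proposed batched versions: sum the per-point values over xs.
f_batched(θ, xs) = sum(f(θ, x_i) for x_i in xs)
g_batched(θ, xs) = sum(g(θ, x_i) for x_i in xs)

θ = [1.0, -1.0]
xs = [([1.0, 0.0], 0.5), ([0.0, 2.0], 1.0), ([1.0, 1.0], 0.0)]

f_batched(θ, xs)  # 0.125 + 4.5 + 0.0 = 4.625
g_batched(θ, xs)  # [0.5, 0.0] + [0.0, -6.0] + [0.0, 0.0] = [0.5, -6.0]
```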

@pokutta
Member

pokutta commented Jan 18, 2021

Good question. I would say we leave it as is for now. The reason is that we need to think about how best to map, e.g., variance-reduced methods, as they need special batch sizes depending on the iteration.

@matbesancon
Member Author

OK, yes. Even with the alternative version, each iteration can control the batch size by choosing the size of the xs list passed to f_batched or g_batched.
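A minimal sketch of that point, assuming the batched interface: the calling loop decides the batch size at each iteration simply by how many sampled points it passes in. The doubling schedule `batch_size_schedule` and the helper `sampled_gradients` are hypothetical stand-ins for whatever a variance-reduced method would need; only the gradient estimates are computed, not the Frank-Wolfe step itself.

```julia
using Random, LinearAlgebra

# Per-point least-squares gradient; a data point is a tuple (a, b).
g(θ, x) = (dot(x[1], θ) - x[2]) * x[1]
g_batched(θ, xs) = sum(g(θ, x_i) for x_i in xs)

# Hypothetical schedule: the batch size doubles each iteration, capped at 64.
batch_size_schedule(t) = min(2^t, 64)

# Returns the averaged gradient estimate and the batch size used per iteration.
function sampled_gradients(θ, data, rng, n_iter)
    grads = Vector{Vector{Float64}}()
    sizes = Int[]
    for t in 0:n_iter-1
        bs = batch_size_schedule(t)
        # The caller controls the batch size through the length of xs.
        xs = data[rand(rng, eachindex(data), bs)]
        push!(sizes, length(xs))
        push!(grads, g_batched(θ, xs) ./ bs)
    end
    return grads, sizes
end

rng = MersenneTwister(0)
data = [(randn(rng, 2), randn(rng)) for _ in 1:100]
grads, sizes = sampled_gradients([0.5, -0.5], data, rng, 4)
```

Here `sizes` comes out as `[1, 2, 4, 8]`: the user-provided `g_batched` never needs to know the schedule.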

@pokutta
Member

pokutta commented Jan 18, 2021

OK, I will need some extra explanation tomorrow so we can discuss this.

@matbesancon
Member Author

So for now, at the FW function level, we have this:

compute_gradient(f, x, rng=rng, batch_size=batch_size)

At the compute_gradient level for f::StochasticObjective:

    rand_indices = if full_evaluation
        eachindex(f.xs)
    else
        rand(rng, eachindex(f.xs), batch_size)
    end
    return sum(f.grad(θ, f.xs[idx]) for idx in rand_indices)
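The snippet above can be turned into a self-contained sketch. The struct below is a hypothetical stand-in for `StochasticObjective` that models only the two fields the snippet uses (`grad` and `xs`); the keyword defaults are assumptions, not the package's actual signature.

```julia
using Random

# Hypothetical stand-in for StochasticObjective: only the fields used above.
struct StochasticObjective{G,X}
    grad::G        # per-point gradient: grad(θ, x_i)
    xs::Vector{X}  # data points
end

# Current behaviour (sketch): batching lives here, f.grad sees one point.
function compute_gradient(f::StochasticObjective, θ;
                          rng=Random.default_rng(),
                          batch_size=length(f.xs),
                          full_evaluation=false)
    rand_indices = if full_evaluation
        eachindex(f.xs)
    else
        rand(rng, eachindex(f.xs), batch_size)  # sample with replacement
    end
    return sum(f.grad(θ, f.xs[idx]) for idx in rand_indices)
end

# Usage: scalar toy gradient θ - x over three data points.
obj = StochasticObjective((θ, x) -> θ - x, [1.0, 2.0, 3.0])
full = compute_gradient(obj, 0.0; full_evaluation=true)        # -1 - 2 - 3 = -6.0
sampled = compute_gradient(obj, 0.0; rng=MersenneTwister(1), batch_size=5)
```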

So compute_gradient is where the default batching behaviour is defined: it calls f.grad on individual data points. The proposed change is that compute_gradient would pass the rng, batch_size and full_evaluation arguments down to f.grad, the user-defined function, which would then implement the batching itself.
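A minimal sketch of the proposed behaviour, under the assumption that the user function receives the whole data set plus the sampling arguments as keywords. `BatchedStochasticObjective` and `my_grad` are hypothetical names used only for illustration; the thread does not fix a signature.

```julia
using Random

# Hypothetical variant: the user-supplied gradient does the batching.
struct BatchedStochasticObjective{G,X}
    grad::G        # grad(θ, xs; rng, batch_size, full_evaluation)
    xs::Vector{X}
end

# compute_gradient only forwards the sampling arguments to the user function.
function compute_gradient(f::BatchedStochasticObjective, θ;
                          rng=Random.default_rng(),
                          batch_size=length(f.xs),
                          full_evaluation=false)
    return f.grad(θ, f.xs; rng=rng, batch_size=batch_size,
                  full_evaluation=full_evaluation)
end

# Example user function for the per-point gradient θ - x: it samples its own
# batch and evaluates it in one expression instead of one call per point.
function my_grad(θ, xs; rng, batch_size, full_evaluation)
    batch = full_evaluation ? xs : xs[rand(rng, eachindex(xs), batch_size)]
    return length(batch) * θ - sum(batch)
end

obj = BatchedStochasticObjective(my_grad, [1.0, 2.0, 3.0])
full = compute_gradient(obj, 0.0; full_evaluation=true)  # 0 - 6 = -6.0
```

The trade-off matches the discussion above: the user gains control over how a batch is evaluated, at the cost of reimplementing the sampling logic that compute_gradient currently provides.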

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants