Re-Introduce ShardedForm for large expressions #937

Open
lassepe opened this issue Jul 26, 2023 · 5 comments · May be fixed by #954

Comments

@lassepe

lassepe commented Jul 26, 2023

The removal of ShardedForm for large arrays in 616ef52 has caused a major performance regression in one of my projects. Compilation time is now about 10x longer. Beyond that, calling the resulting function from 23 distributed workers in parallel previously worked just fine. Now it quickly runs out of memory on a machine with 64 GB of RAM, presumably because every worker has to compile the function on its first call, and compilation is more memory-intensive for the SerialForm.

Is there a way to default to ShardedForm again for large functions, or is this a fundamental limitation of RuntimeGeneratedFunctions.jl?

@lassepe lassepe changed the title Major regression in compile time and memory footprint Re-Introduce ShardedForm for large expressions Jul 26, 2023
@ChrisRackauckas
Member

Someone just needs to fix it. We cannot default to it if it's not correct. If you're willing to fix its dependency analysis then we'd be happy to re-enable it.

@lassepe
Author

lassepe commented Aug 9, 2023

I am happy to take a look. I have little exposure to this area, so I have no idea whether I can be of actual use here. Are there any more pointers or evidence as to what exactly the issue is?

@ChrisRackauckas
Member

The observed equations are not prepended to the sharded equations, so it errors with any observables. At least that's what someone mentioned to me 2 weeks ago (@shashi?)

@shashi shashi linked a pull request Aug 16, 2023 that will close this issue
@shashi
Member

shashi commented Aug 16, 2023

I would like to see a reproducer from you guys for the multithreading deadlock. The problem might have gone away now. @wsphillips was saying it gave wrong answers. But just from the code, I don't see a chance of deadlock or race conditions, unless you are passing repeated indices into build_function, in which case even the serial version would be wrong.
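To make the repeated-indices point concrete, here is a minimal sketch in plain Julia. It does not use Symbolics at all: `write_serial!` and `write_threaded!` are hypothetical stand-ins for generated in-place assignments of the form `out[i] = value`, and are only meant to show why duplicate output indices become order-dependent once the writes run on several threads.

```julia
# Hypothetical stand-ins for generated in-place writes (not Symbolics code).
# Each (index, value) pair plays the role of one assignment `out[i] = value`.
function write_serial!(out, idxs, vals)
    for (i, v) in zip(idxs, vals)
        out[i] = v                  # later writes overwrite earlier ones
    end
    return out
end

function write_threaded!(out, idxs, vals)
    Threads.@threads for k in eachindex(idxs, vals)
        out[idxs[k]] = vals[k]      # write order across tasks is not defined
    end
    return out
end

idxs = [1, 2, 2]                    # index 2 appears twice
vals = [10.0, 20.0, 30.0]

# Cheap guard one could run before choosing a threaded form.
has_duplicates = !allunique(idxs)

# Serial execution is deterministic: the last write to index 2 wins.
serial_result = write_serial!(zeros(2), idxs, vals)
```

With unique indices both versions agree. With a repeated index the serial result is deterministic (last write wins) but probably not what was intended, and the threaded result depends on scheduling, so an `allunique` check on the indices is a cheap way to rule this case out.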

@wsphillips
Collaborator

wsphillips commented Aug 16, 2023

The fix from @shashi solves the observable scoping issue. But when I tested it with some Conductor.jl models, it returns wrong solutions when MultithreadedForm() is used (the serially executed sharded form is fine).

You can use this MWE I adapted from the MTK docs to reproduce:

Edit: using Shashi's opaque closure PR

using ModelingToolkit, Plots, OrdinaryDiffEq, LinearAlgebra
using Symbolics: scalarize

@variables t
D = Differential(t)

function Mass(; name, m = 1.0, xy = [0.0, 0.0], u = [0.0, 0.0])
    ps = @parameters m = m
    sts = @variables pos(t)[1:2]=xy v(t)[1:2]=u
    eqs = scalarize(D.(pos) .~ v)
    ODESystem(eqs, t, [pos..., v...], ps; name)
end

function Spring(; name, k = 1e4, l = 1.0)
    ps = @parameters k=k l=l
    @variables x(t), dir(t)[1:2]
    ODESystem(Equation[], t, [x, dir...], ps; name)
end

function connect_spring(spring, a, b)
    [spring.x ~ norm(scalarize(a .- b))
        scalarize(spring.dir .~ scalarize(a .- b))]
end

function spring_force(spring)
    -spring.k .* scalarize(spring.dir) .* (spring.x - spring.l) ./ spring.x
end

m = 1.0
xy = [1.0, -1.0]
k = 1e4
l = 1.0
center = [0.0, 0.0]
g = [0.0, -9.81]
@named mass = Mass(m = m, xy = xy)
@named spring = Spring(k = k, l = l)

eqs = [connect_spring(spring, mass.pos, center)
    scalarize(D.(mass.v) .~ spring_force(spring) / mass.m .+ g)]

@named _model = ODESystem(eqs, t, [spring.x; spring.dir; mass.pos], [])
@named model = compose(_model, mass, spring)
sys = structural_simplify(model)

# if parallel = `Symbolics.ShardedForm(2,2)` or default serial form this works fine
prob = ODEProblem(sys, [], (0.0, 3.0); parallel = Symbolics.MultithreadedForm(2,2))
sol = solve(prob, Rosenbrock23())
plot(sol) # no oscillations/wrong output when multithreaded
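Beyond eyeballing the plot, one way to quantify the discrepancy is to compare the multithreaded solution against a serially executed reference on a shared time grid. The helper below is a sketch, not part of any package: `max_deviation` is a hypothetical name, and it only assumes that both arguments can be called as `sol(t)` and return the state vector at time `t`, which is how dense ODE solutions are typically evaluated.

```julia
# Hypothetical helper: largest absolute difference between two trajectories
# sampled on the same time grid `ts`. Both `sol_a` and `sol_b` only need to
# be callable as sol(t) -> vector of states.
function max_deviation(sol_a, sol_b, ts)
    maximum(ts) do t
        maximum(abs, sol_a(t) .- sol_b(t))
    end
end

# Intended use with the MWE above (assumes `sys` from that script):
#
#   ts       = range(0.0, 3.0; length = 200)
#   prob_ref = ODEProblem(sys, [], (0.0, 3.0); parallel = Symbolics.ShardedForm(2, 2))
#   prob_mt  = ODEProblem(sys, [], (0.0, 3.0); parallel = Symbolics.MultithreadedForm(2, 2))
#   max_deviation(solve(prob_ref, Rosenbrock23()), solve(prob_mt, Rosenbrock23()), ts)
```

If the two forms generate equivalent code, the deviation should stay within a small multiple of the solver tolerances. A value on the order of the solution amplitude would confirm that the multithreaded right-hand side itself is wrong, rather than the plot merely looking different.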
