Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Parsing with Symbolics #678

Open
wants to merge 49 commits into
base: master
Choose a base branch
from
Open

Improve Parsing with Symbolics #678

wants to merge 49 commits into from

Conversation

xtalax
Copy link
Member

@xtalax xtalax commented Apr 12, 2023

The current setup was written before we had all the useful tools present in Symbolics.jl, this PR takes advantage of these and the PDE helper functions present in PDEBase

@YichengDWu
Copy link
Member

YichengDWu commented May 20, 2023

I recently completely rewrote the parsing in YichengDWu/Sophon.jl#204, and it is now close to the ideal state I mentioned in #687. I used pattern matching from MacroTools there to transform expressions. Although I don't understand what you are doing here, I guess converting it to Symbolics-based rewriting would be straightforward. I hope this helps with this PR.

src/new_loss.jl Outdated
function generate_derivative_rules(eq, eqdata, dummyvars)
phi, u, coord, θ = dummyvars
@register_symbolic derivative(phi, u, coord, εs, order, θ)
rs = [[@rule $(Differential(x)^(~d)(w)) => derivative(phi, u, coord, get_εs(w), d, θ)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should just be a single directional vector? get_ε(w). Mixed derivatives are handles outside of the derivative function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just trying to find a way to do this without having to generate rules for every combination of variables

@xtalax
Copy link
Member Author

xtalax commented May 26, 2023

@ChrisRackauckas @YichengDWu
Ok, here is what I expect will work, I'd appreciate a review and perhaps a comment on how mixed deriv rules can be generated without being exhaustive. @YichengDWu can you clarify what needs to change for mixed derivs to work as in your PR?

We have lost functionality for multioutput and integrals, to be re-added in future PRs, is this acceptable?

@YichengDWu
Copy link
Member

I'm not sure if I understand

@rule $(Differential(~x)^(~d::isinteger)(~w)) => derivative(phi, u, ~x, get_εs(~w), ~d, θ)

It looks like we should pass the coordinate to derivative not just ~x here? Anyway, mimicking this rule, mixed derivatives can be handle like the following (but not exactly):

@rule $(Differential(~x1)^(~d1::isinteger)($(Differential(~x2)^(~d2::isinteger)(~w))))
=>  derivative(phi, (cord_, θ_, phi_) ->derivative(phi_, u, ~x2, get_ε(~w), ~d2, θ_), u, ~x1, get_ε(~w), ~d1, θ)

Can you symbolic_discretize some test PDE to inspect the generated expression?

@YichengDWu
Copy link
Member

We have lost functionality for multioutput and integrals, to be re-added in future PRs, is this acceptable?

Or move the code into a submodule, and use Preferences to switch between the old and new backends, until the day when the old parsing backend is completely replaced?

@xtalax
Copy link
Member Author

xtalax commented May 31, 2023

This changes fundamentally a large part of the codebase, it will add a lot of complexity to do this, I'd rather keep them on different major versions instead

@ChrisRackauckas
Copy link
Member

If anyone needs something backported to the major that's fine, but I'd like to just out with the old and in with the new. On a major it's fine to be breaking, and I think if we have a functionality regression for a major change in the maintainership of this package it's fine. Honestly the integro-differential equation piece needs a bit more though on part of the definition anyways, so we should come back to it but I wouldn't let that hold it up. And the multioutput is a very minor feature that I wouldn't say is an important one.

@sathvikbhagavan
Copy link
Member

The loss functions generated are incorrect.

Running this example:

@parameters θ
@variables u(..)
Dθ = Differential(θ)

# 1D ODE
eq = (u(θ)) ~ θ^3 + 2 * θ +^2) * ((1 + 3 *^2)) / (1 + θ +^3))) -
                u(θ) *+ ((1 + 3 *^2)) / (1 + θ + θ^3)))

# Initial and boundary conditions
bcs = [u(0.0) ~ 1.0]

# Space and time domains
domains = Interval(0.0, 1.0)]

# Neural network
chain = Lux.Chain(Lux.Dense(1, 12, Flux.σ), Lux.Dense(12, 1))

discretization = NeuralPDE.PhysicsInformedNN(chain, GridTraining(0.1))

@named pdesys = PDESystem(eq, bcs, domains, [θ], [u(θ)])
prob = NeuralPDE.discretize(pdesys, discretization)
prob.f.f(prob.u0, nothing) # Calculate loss

Calculating the loss is erroring out.

I am getting Runtime Generated Function in https://github.com/xtalax/NeuralPDE.jl/blob/parsing/src/loss_function_generation.jl#L127 as

ex = :(function (var"##arg#8771009302963811065", θ_SYMBOL, phi, var"##arg#5964805160111424296")
      #= /home/sathvikbhagavan/.julia/packages/SymbolicUtils/ssQsQ/src/code.jl:373 =#
      #= /home/sathvikbhagavan/.julia/packages/SymbolicUtils/ssQsQ/src/code.jl:374 =#
      #= /home/sathvikbhagavan/.julia/packages/SymbolicUtils/ssQsQ/src/code.jl:375 =#
      begin
          θ = var"##arg#8771009302963811065"[1]
          nothing
      end
  end)

which does not look correct. I inserted show statements in https://github.com/xtalax/NeuralPDE.jl/blob/parsing/src/loss_function_generation.jl#L117 and https://github.com/xtalax/NeuralPDE.jl/blob/parsing/src/loss_function_generation.jl#L118 and found:

expr = 82570.18592591035(-phi(SymbolicUtils.BasicSymbolic{Real}[-6.055454452393343e-6 + θ], θ_SYMBOL) + phi(SymbolicUtils.BasicSymbolic{Real}[6.055454452393343e-6 + θ], θ_SYMBOL)) + ((-1 - 3^2))*^2)) / (1 + θ + θ^3) - 2θ + ((1 + 3^2)) / (1 + θ + θ^3) + θ)*phi(SymbolicUtils.BasicSymbolic{Real}[θ], θ_SYMBOL) -^3)
expr = nothing

So doing expr = swch(expr) in https://github.com/xtalax/NeuralPDE.jl/blob/parsing/src/loss_function_generation.jl#L118 is returning nothing as there is no switch in the first expression.

So, there might be a bug in how the symbolic rules are written. I am not familiar with SymbolicUtils that much to point out the bug.

@xtalax

cc: @ChrisRackauckas

@xtalax
Copy link
Member Author

xtalax commented Feb 9, 2024

Try wrapping the rule in a chain before postwalk, and maybe also a vector would be my guess

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants