Multi-step loading + auto-diff #5

Open
SNMS95 opened this issue Oct 3, 2023 · 2 comments
Labels
enhancement New feature or request

Comments

@SNMS95
Collaborator

SNMS95 commented Oct 3, 2023

  • In the plasticity example, the load is applied in small increments. Load stepping also helps iterative solvers converge in other nonlinear cases.
  • Solution:
    • Just as ad_wrapper performs implicit differentiation through a non-linear solve, we should also allow implicit differentiation through load steps.
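
To make the idea concrete, here is a minimal JAX sketch on a toy scalar problem. It is not JAX-FEM code: `residual`, `newton`, `solve`, and `load_stepped_solve` are hypothetical names, and the residual `k*u + u**3 - f` is an assumed stand-in for a nonlinear FE system. The intermediate load steps only produce a warm start and are excluded from differentiation; the final solve is differentiated implicitly through a `jax.custom_vjp` rule.

```python
import jax
import jax.numpy as jnp

def residual(u, k, f):
    # Toy nonlinear equilibrium (assumption): stiffness k, cubic hardening, load f.
    return k * u + u**3 - f

def newton(u0, k, f, iters=30):
    # Plain Newton iteration on the scalar residual, started from the guess u0.
    u = u0
    for _ in range(iters):
        u = u - residual(u, k, f) / (k + 3.0 * u**2)
    return u

@jax.custom_vjp
def solve(k, f, u_guess):
    # Converged solution of residual(u, k, f) = 0; u_guess only affects convergence.
    return newton(u_guess, k, f)

def solve_fwd(k, f, u_guess):
    u = newton(u_guess, k, f)
    return u, (u, k, f)

def solve_bwd(saved, u_bar):
    u, k, f = saved
    # Implicit function theorem: solve the adjoint equation (dr/du) * lam = u_bar.
    lam = u_bar / (k + 3.0 * u**2)
    # dr/dk = u and dr/df = -1; the converged u does not depend on the guess.
    return (-lam * u, lam, jnp.zeros_like(u))

solve.defvjp(solve_fwd, solve_bwd)

def load_stepped_solve(k, loads):
    # Ramp the load; intermediate steps are warm starts only and carry no gradient.
    u = jnp.zeros_like(k)
    for f in loads[:-1]:
        u = jax.lax.stop_gradient(solve(jax.lax.stop_gradient(k), f, u))
    # Only the final solve is differentiated, via the implicit rule above.
    return solve(k, loads[-1], u)

loads = jnp.array([1.0, 2.0, 3.0])
u_final = load_stepped_solve(2.0, loads)
du_dk = jax.grad(load_stepped_solve)(2.0, loads)
```

For this toy residual the implicit-function result is `du/dk = -u / (k + 3*u**2)`, which the gradient above should reproduce. The shortcut relies on the final state being independent of the loading path; for path-dependent material models such as plasticity, the history would also have to enter the derivative.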
@SNMS95 SNMS95 added the enhancement New feature or request label Oct 3, 2023
@SNMS95
Collaborator Author

SNMS95 commented Oct 11, 2023

def non_trace_force_stepping_fn(design, force_magnitudes):
    # Ramp the load step by step, warm-starting each solve from the previous solution
    u_init = 0.0
    for force_magnitude in force_magnitudes:
        # Update the Neumann BC for the current load magnitude
        problem_bcs = available_problem_bcs.ProblemBCsDatabase.get_problem('cantilever_2D', box_domain,
                                                                            force_magnitude=force_magnitude)
        fe_problem.neumann_bc_info = problem_bcs.neumann_bc_info

        # Solve the forward problem (u_init is the proposed initial-guess argument)
        fwd_pred = ad_wrapper(problem=fe_problem, linear=True, use_petsc=False,
                              u_init=u_init)
        u = fwd_pred(design)
        u_init = u
    return u, fwd_pred

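def pipeline_fn(params, state_of_params=None):
    if state_of_params is None:  # avoid a shared mutable default argument
        state_of_params = {}
    design, state_of_params = parametrizer_fn_with_x(params=params,
                                                     state=state_of_params)
    design = cone_filter_mapping(design)
    # Load steps run on a gradient-stopped design, so autodiff does not trace them
    u_final, fwd_pred = non_trace_force_stepping_fn(jax.lax.stop_gradient(design), force_magnitudes)
    # Differentiable solve at the final load, warm-started from the stepped solution
    fe_solution = fwd_pred([design], u_init=u_final)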
    objective_val = objective_fn(fe_solution)
    constraint_val = constraint_fn(design)
    return {
        "fe_solution": fe_solution,
        "design": design,
        "objective": objective_val,
        "constraint": constraint_val,
        "nn_state": state_of_params,
         }

This seems like an easy solution.
Changes to be made:

  1. Adjust ad_wrapper to accept an initial guess, u_init_guess (written as u_init in the snippet above) [Modify JAX-FEM]
  2. Expose force_magnitude in problem_bcs [TO-JAX]

@acse-itk22

@h-vijayakumaran

I think the jax-am example for plasticity would be a good starting point.
