
enhancement: support for jaxopt as a more robust stst solver #13

Closed
drhboss opened this issue May 6, 2024 · 5 comments

Comments

@drhboss

drhboss commented May 6, 2024

I'm trying to solve a rather complex model (~100 equations). The problem is that the Newton solver is not robust enough: if the initial values are a bit off, it does not converge. Is there any chance to implement more robust solvers, or to add support for using optimization routines from jaxopt to solve the model?

@gboehl
Owner

gboehl commented May 6, 2024

Hi. Hm, I can't think of an application where the Newton-with-pseudoinverse solver would not be the best/most robust choice, but maybe it's just me. Did you check whether the rank of the Jacobian and the number of provided steady-state values are aligned? This is normally the problem. I do not think that jaxopt/optax will be of much help, because they mostly target optimization and their root-finding routines are pretty basic. But I'm happy to be convinced otherwise.
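A minimal way to run the rank check suggested above, sketched in NumPy under a few assumptions: `fd_jacobian` and `toy_residuals` are hypothetical helpers (not econpizza functions), the Jacobian is approximated by forward differences, and the rank tolerance of 1e-6 is an ad hoc choice that sits above the finite-difference noise. With a real model you would apply the same check to the Jacobian of the compiled steady-state function.

```python
import numpy as np


def fd_jacobian(f, x, h=1e-7):
    """Forward-difference Jacobian of f: R^n -> R^m at x (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    f0 = np.asarray(f(x), dtype=float)
    J = np.empty((f0.size, x.size))
    for j in range(x.size):
        xh = x.copy()
        xh[j] += h
        J[:, j] = (np.asarray(f(xh), dtype=float) - f0) / h
    return J


def toy_residuals(x):
    # three equations in three unknowns, but the third equation is the
    # sum of the first two, so it carries no additional information
    r1 = x[0] + x[1] - 1.0
    r2 = x[1] - x[2]
    return np.array([r1, r2, r1 + r2])


x0 = np.array([0.2, 0.3, 0.4])
J = fd_jacobian(toy_residuals, x0)
# explicit tolerance: finite differences leave noise of roughly 1e-9
rank = np.linalg.matrix_rank(J, tol=1e-6)
print(f"unknowns: {x0.size}, Jacobian rank: {rank}")
```

Here the rank (2) falls short of the number of unknowns (3) because one equation is a linear combination of the others, so one unknown is not pinned down by the system. That kind of mismatch is what the comment above points to.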

Custom solvers are currently not supported, but I agree that this would be a nice feature. What you could do is take the steady-state function and solve it with the solver of your choice. The steady-state function is compiled when running solve_stst and is then added to the model's context:

https://github.com/gboehl/econpizza/blob/master/econpizza/solvers/steady_state.py#L121

So what you could do is:

# ...parse your model first (e.g. mod = ep.load(...) after `import econpizza as ep`). Then:

# needed below
from econpizza.parser import d2jnp

# compile the steady-state function, but do not raise if the solver fails
res_stst = mod.solve_stst(raise_errors=False)

# initial values contain all variables and parameters that are not in fixed_values
init_vals = d2jnp(res_stst['initial_values']['guesses'])
# obtain the compiled steady-state function
stst_fun = mod['context']['func_stst']
# solve it; `your_root_finder` is a placeholder for the root finder of your choice
stst = your_root_finder(stst_fun, init_vals)
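The placeholder `your_root_finder` can be any routine that takes a residual function and a starting vector. As one hedged example, here is a minimal NumPy sketch of a damped Newton method with least-squares (pseudo-inverse) steps and a step-halving line search. It assumes that `stst_fun` maps a 1-D array of unknowns to a 1-D array of residuals (not verified against econpizza), and it approximates the Jacobian by finite differences to stay dependency-free; for a JAX function, `jax.jacfwd` would give an exact Jacobian instead.

```python
import numpy as np


def your_root_finder(fun, x0, tol=1e-10, maxit=100, h=1e-7):
    """Damped Newton iteration with least-squares (pseudo-inverse) steps.

    Hypothetical sketch: `fun` is assumed to map a 1-D array of unknowns
    to a 1-D array of residuals. This is not econpizza's own solver.
    """
    x = np.asarray(x0, dtype=float).copy()
    f = np.asarray(fun(x), dtype=float)
    for _ in range(maxit):
        if np.max(np.abs(f)) < tol:
            break
        # forward-difference Jacobian (for a JAX function, jax.jacfwd is exact)
        J = np.empty((f.size, x.size))
        for j in range(x.size):
            xh = x.copy()
            xh[j] += h
            J[:, j] = (np.asarray(fun(xh), dtype=float) - f) / h
        # minimum-norm least-squares Newton step; also defined for singular J
        step = np.linalg.lstsq(J, -f, rcond=None)[0]
        # backtracking line search: halve the step until the residual norm falls
        lam = 1.0
        while True:
            x_new = x + lam * step
            f_new = np.asarray(fun(x_new), dtype=float)
            if np.linalg.norm(f_new) < np.linalg.norm(f) or lam < 1e-8:
                break
            lam /= 2
        x, f = x_new, f_new
    return x


# usage: a unit circle intersected with the line x0 = x1
root = your_root_finder(
    lambda x: np.array([x[0] ** 2 + x[1] ** 2 - 1.0, x[0] - x[1]]),
    np.array([1.0, 0.5]),
)
# usage with a singular Jacobian: the second equation duplicates the first
root_sing = your_root_finder(
    lambda x: np.array([x[0] + x[1] - 1.0, 2.0 * (x[0] + x[1] - 1.0)]),
    np.zeros(2),
)
```

The line search is one common way to make Newton's method more forgiving of poor starting values; it is shown here as an illustration and is not a claim about how econpizza's `newton_jax` works.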

You should then either run solve_stst again with the correct initial values, or update mod['stst'] and mod['pars'] accordingly (see here). Let me know if this helps. If not, adding support for custom solvers should be quite straightforward.

@drhboss
Author

drhboss commented May 15, 2024

Thank you for the notes. I was able to plug in my own solver and pass it to your newton_jax steady-state solver. Everything works fine now.

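import jax.numpy as jnp

def solve_singular(A, b):
    # reduced SVD: A = U @ diag(s) @ Vh (jnp.linalg.svd returns V transposed)
    U, s, Vh = jnp.linalg.svd(A, full_matrices=False)
    # cutoff below which singular values are treated as zero
    threshold = jnp.finfo(A.dtype).eps * max(A.shape) * s[0]
    # invert the singular values above the cutoff, zero out the rest
    s_inv = jnp.where(s > threshold, 1 / s, 0)
    # x = V @ diag(s_inv) @ U^T @ b, without forming the pseudo-inverse explicitly
    x = Vh.T @ (s_inv * (U.T @ b))
    return x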
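As a quick sanity check of the approach, the sketch below is a NumPy port of the logic of `solve_singular` (reduced SVD, same cutoff), written in NumPy only so it runs without JAX, compared with `np.linalg.lstsq` on a rank-deficient system. With `rcond=None`, `lstsq` treats singular values below machine precision times `max(M, N)` times the largest singular value as zero, which is the same cutoff, so both should return the minimum-norm least-squares solution. The name `solve_singular_np` is illustrative.

```python
import numpy as np


def solve_singular_np(A, b):
    """NumPy port of solve_singular: least-squares solve via truncated SVD."""
    # reduced SVD: A = U @ diag(s) @ Vh
    U, s, Vh = np.linalg.svd(A, full_matrices=False)
    # same cutoff: eps * max(A.shape) * largest singular value
    threshold = np.finfo(A.dtype).eps * max(A.shape) * s[0]
    # invert singular values above the cutoff, leave the rest at zero
    s_inv = np.divide(1.0, s, out=np.zeros_like(s), where=s > threshold)
    return Vh.T @ (s_inv * (U.T @ b))


# rank-deficient system: the second row is twice the first
A = np.array([[1.0, 1.0], [2.0, 2.0]])
b = np.array([1.0, 2.0])

x_svd = solve_singular_np(A, b)
x_lstsq = np.linalg.lstsq(A, b, rcond=None)[0]
print(x_svd, x_lstsq)  # both approximately [0.5, 0.5], the minimum-norm solution
```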

drhboss closed this as completed May 15, 2024
@gboehl
Owner

gboehl commented May 16, 2024

Great that it worked. Just out of interest, I was wondering what the difference between your solver and mine is. We both do nothing more than compute the pseudo-inverse. Is your threshold different?

@drhboss
Author

drhboss commented May 16, 2024

That's right, they are essentially the same solver. Doing it this way lets me adjust the threshold manually when the model doesn't solve; however, that doesn't work every time. I do have a model that none of the solvers can solve, but I have a solution obtained manually. If you're interested, I can share it with you.
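To illustrate the manual threshold tuning described above, here is a minimal NumPy sketch in which the cutoff is an argument. The function name `solve_singular_rcond` and its `rcond` parameter are hypothetical (modelled on the `rcond` argument of `np.linalg.lstsq`), and the nearly singular test system is made up to show how a larger cutoff discards a weakly identified direction.

```python
import numpy as np


def solve_singular_rcond(A, b, rcond=None):
    """Truncated-SVD solve of A x = b with an adjustable relative cutoff.

    `rcond` is a hypothetical parameter: singular values below
    rcond * (largest singular value) are treated as zero. rcond=None
    reproduces the default cutoff eps * max(A.shape).
    Returns the solution and the number of singular values kept.
    """
    U, s, Vh = np.linalg.svd(A, full_matrices=False)
    if rcond is None:
        rcond = np.finfo(A.dtype).eps * max(A.shape)
    keep = s > rcond * s[0]
    s_inv = np.divide(1.0, s, out=np.zeros_like(s), where=keep)
    return Vh.T @ (s_inv * (U.T @ b)), int(keep.sum())


# nearly singular system: the two rows differ only by 1e-8 in one entry
A = np.array([[1.0, 1.0], [1.0, 1.0 + 1e-8]])
b = np.array([2.0, 2.0 + 1e-6])

x_default, rank_default = solve_singular_rcond(A, b)  # keeps both directions
x_trunc, rank_trunc = solve_singular_rcond(A, b, rcond=1e-6)  # drops the weak one
```

With the default cutoff both singular values are kept and the solve returns roughly [-98, 100]; with `rcond=1e-6` the weak direction is dropped and the result is roughly [1, 1], which satisfies the second equation only up to about 1e-6. Neither answer is right in general, which fits the observation that tuning the threshold helps for some models but not for all.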

@gboehl
Owner

gboehl commented May 16, 2024

Sure!
