Vector-valued objective functions #1556
Hi @remypa, we haven't intended to support this, nor have we tested it before. However, we can put this on our roadmap for future versions. It's possible some of my changes in #1551 may make it work, but it's unlikely. What I don't understand exactly is why you'd need to do this; perhaps you can explain. My understanding is that you have a set of objective functions $f_i(x)$. Typically you'd combine all of these into a single objective function, e.g. by summing over $i$. Ultimately, if you are able to combine everything you'd like to do into a single scalar objective function, it should still be possible to use `jax.value_and_grad`. Note that you can always use `web.run_async` inside that single objective to batch your simulations.
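As a concrete sketch of the "single combined objective" suggestion above, with `f`, `m`, and the quadratic toy objectives as hypothetical placeholders rather than anything from the tidy3d API:

```python
import jax
import jax.numpy as jnp

m = 3

def f(x, i):
    # toy stand-in for the i-th objective; in practice this would run a
    # simulation and post-process the monitor data
    return jnp.sum((x - i) ** 2)

def combined_objective(x):
    # summing over i yields a single scalar, so jax.value_and_grad applies
    return sum(f(x, i) for i in range(m))

val, grad = jax.value_and_grad(combined_objective)(jnp.ones(5))
```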
Hi, thanks a lot for your reply. I am working on a minimax-type of problem:

$$\min_x \max_i f_i(x),$$

which I reformulate in epigraph form as

$$\min_{x,\,t} \; t \quad \text{s.t.} \quad f_i(x) - t \le 0 \;\; \forall i,$$

which is why I don't really recombine the $f_i$ into a single scalar objective: each one has to enter the optimization as its own constraint.
I see. Yeah, we will have to work on improving the compatibility between the tidy3d adjoint plugin and nlopt for these sorts of problems. In the meantime, I might suggest you use a softmax-weighted sum of your objectives, which acts as a smooth approximation to the max. Some pseudo-code below, but double check the specifics:

```python
import jax
import jax.numpy as jnp

def objective(x):
    # evaluate all m scalar objectives
    fs = jnp.array([f(x, i) for i in range(m)])
    # softmax weighting smoothly emphasizes the largest objective
    weights = jax.nn.softmax(fs)
    return jnp.sum(weights * fs)
```

My intuition tells me that this should work reasonably well without needing to transform the problem to constraints.
EDIT: forgot to `jnp.sum` the weighted objectives in the original snippet; fixed above.
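One knob worth noting (an addition here, not part of the suggestion above): a temperature factor controls how closely the softmax weighting approximates a hard max.

```python
import jax.numpy as jnp
from jax import nn

def smooth_max(fs: jnp.ndarray, beta: float = 10.0) -> jnp.ndarray:
    # larger beta concentrates the weights on the largest entry, so the
    # result approaches max(fs); beta -> 0 recovers the plain mean
    weights = nn.softmax(beta * fs)
    return jnp.sum(weights * fs)
```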
The direct use of `jax.jacrev` / `jax.jacfwd` with the adjoint plugin is what fails here. You don't need to use a vector-valued objective function for this, though: since nlopt works with plain numpy arrays anyway, you can evaluate each scalar objective (and its gradient) separately with `jax.value_and_grad` and assemble the results into the constraint vector outside of JAX.
Hey @remypa, to add to the above: since everything in nlopt needs to happen outside of JAX anyway, it is perfectly fine to construct the constraint vector as @ianwilliamson described. In the simplest case (single wavelength, only differentiating w.r.t. a single argument), that would look something like this:

```python
import numpy as np

def nlopt_epigraph_constraint(result: np.ndarray, x: np.ndarray, gd: np.ndarray) -> None:
    # x packs the epigraph variable t followed by the design parameters v
    t, v = x[0], x[1:]
    # evaluate all objectives and get their gradients, assuming obj_fun_vgs is a list
    # of gradient functions defined somewhere else, of the form:
    # d_obj = jax.value_and_grad(objective)
    obj_vec, grad_vec = [], []
    for obj_fun_vg in obj_fun_vgs:
        obj_val, grad_val = obj_fun_vg(v)
        obj_vec.append(obj_val)
        grad_vec.append(grad_val)
    if gd.size > 0:
        # d(f_i(v) - t)/dt = -1 and d(f_i(v) - t)/dv = grad f_i(v)
        gd[:, 0] = -1
        gd[:, 1:] = np.asarray(grad_vec)
    # f_i(v) - t <= 0 for all i enforces t >= max_i f_i(v)
    result[:] = np.asarray(obj_vec) - t
```

You can also parallelize the evaluation of the objective functions in that loop with something like `web.run_async`.
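For completeness, a hypothetical sketch of how this vector constraint could be registered with nlopt's Python bindings; `n_params` and `n_obj` are placeholder sizes, and the epigraph variable `t` itself is the quantity being minimized:

```python
import nlopt
import numpy as np

n_params, n_obj = 100, 3  # placeholders

def epigraph_objective(x: np.ndarray, gd: np.ndarray) -> float:
    # minimize t = x[0]; the design parameters x[1:] only enter the constraints
    if gd.size > 0:
        gd[:] = 0.0
        gd[0] = 1.0
    return x[0]

opt = nlopt.opt(nlopt.LD_MMA, n_params + 1)
opt.set_min_objective(epigraph_objective)
opt.add_inequality_mconstraint(nlopt_epigraph_constraint, [1e-8] * n_obj)
opt.set_maxeval(100)
x_opt = opt.optimize(np.concatenate([[1.0], np.random.rand(n_params)]))
```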
Hi, thanks for all your inputs/suggestions. @ianwilliamson, @yaugenst: that is the way I am doing it at the moment. It does work, but the difficulty is that, as far as I can tell, I can't use Tidy3D's batch infrastructure out of the box, hence my initial post. @tylerflex: I have started looking into softmax, and my initial results are promising. Cheers.
I am using Tidy3D to perform multi-objective inverse design. To speed things up, I'd like to run batches with web.run_async and return the different objectives as a vector value (I need the individual values to define multiple constraints in my optimisation problem).
To do that, I can't use jax.value_and_grad (which is limited to scalar functions). Instead, I need to use jax.jacrev or jax.jacfwd.
But when I do that, I run into ConcretizationTypeError if I'm using a FieldMonitor, or a TracerArrayConversionError if I'm using a ModeMonitor.
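To illustrate the pattern being attempted here with a toy function (not tidy3d code): `jax.jacrev` returns the full Jacobian of a vector-valued function, whereas `jax.value_and_grad` would raise an error on a non-scalar output.

```python
import jax
import jax.numpy as jnp

def vector_objective(x: jnp.ndarray) -> jnp.ndarray:
    # two scalar objectives stacked into a single vector output
    return jnp.array([jnp.sum(x**2), jnp.prod(x)])

jac = jax.jacrev(vector_objective)(jnp.arange(3.0))  # Jacobian of shape (2, 3)
```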
A simple way to reproduce the problem is to extend the tutorial at https://www.flexcompute.com/tidy3d/examples/notebooks/AdjointPlugin1Intro/, with:

This is not a vector-valued objective function, but the problem is the same: