Vector-valued objective functions #1556

Open
remypa opened this issue Mar 17, 2024 · 7 comments

remypa commented Mar 17, 2024

I am using Tidy3D to perform multi-objective inverse design. To speed things up, I'd like to run batches with web.run_async and return the different objectives as a vector value (I need the individual values to define multiple constraints in my optimisation problem).

To do that, I can't use jax.value_and_grad (which is limited to scalar-valued functions); instead, I need to use jax.jacrev or jax.jacfwd.

But when I do that, I run into a ConcretizationTypeError if I'm using a FieldMonitor, or a TracerArrayConversionError if I'm using a ModeMonitor.

A simple way to reproduce the problem is to extend the tutorial at https://www.flexcompute.com/tidy3d/examples/notebooks/AdjointPlugin1Intro/ with:

import jax

jac = jax.jacrev(power, argnums=(0, 1, 2))
d_power = jac(center0, size0, eps0)

This is not a vector-valued objective function, but the problem is the same:

Traceback (most recent call last):

  File ~/micromamba/envs/tidy3d/lib/python3.10/runpy.py:196 in _run_module_as_main
    return _run_code(code, main_globals, None,

  File ~/micromamba/envs/tidy3d/lib/python3.10/runpy.py:86 in _run_code
    exec(code, run_globals)

  File ~/.local/lib/python3.10/site-packages/spyder_kernels/console/__main__.py:24
    start.main()

  File ~/.local/lib/python3.10/site-packages/spyder_kernels/console/start.py:340 in main
    kernel.start()

  File ~/.local/lib/python3.10/site-packages/ipykernel/kernelapp.py:724 in start
    self.io_loop.start()

  File ~/.local/lib/python3.10/site-packages/tornado/platform/asyncio.py:215 in start
    self.asyncio_loop.run_forever()

  File ~/micromamba/envs/tidy3d/lib/python3.10/asyncio/base_events.py:595 in run_forever
    self._run_once()

  File ~/micromamba/envs/tidy3d/lib/python3.10/asyncio/base_events.py:1881 in _run_once
    handle._run()

  File ~/micromamba/envs/tidy3d/lib/python3.10/asyncio/events.py:80 in _run
    self._context.run(self._callback, *self._args)

  File ~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:512 in dispatch_queue
    await self.process_one()

  File ~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:501 in process_one
    await dispatch(*args)

  File ~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:408 in dispatch_shell
    await result

  File ~/.local/lib/python3.10/site-packages/ipykernel/kernelbase.py:731 in execute_request
    reply_content = await reply_content

  File ~/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py:417 in do_execute
    res = shell.run_cell(

  File ~/.local/lib/python3.10/site-packages/ipykernel/zmqshell.py:540 in run_cell
    return super().run_cell(*args, **kwargs)

  File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:2945 in run_cell
    result = self._run_cell(

  File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3000 in _run_cell
    return runner(coro)

  File ~/.local/lib/python3.10/site-packages/IPython/core/async_helpers.py:129 in _pseudo_sync_runner
    coro.send(None)

  File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3203 in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,

  File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3382 in run_ast_nodes
    if await self.run_code(code, result, async_=asy):

  File ~/.local/lib/python3.10/site-packages/IPython/core/interactiveshell.py:3442 in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)

  Cell In[7], line 2
    d_power = jac(center0, size0, eps0)

  File ~/rare_earth_ions/simulation/problem_jacobian.py:100 in power
    jax_sim_data = run_adjoint(jax_sim, task_name="adjoint_power", verbose=True)

JaxStackTraceBeforeTransformation: jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex64[1,1].
This BatchTracer with object id 140334367546704 was created on line:
  /home/pr/rare_earth_ions/simulation/problem_jacobian.py:94 (compute_power)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------


The above exception was the direct cause of the following exception:

Traceback (most recent call last):

  Cell In[7], line 2
    d_power = jac(center0, size0, eps0)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/api.py:951 in jacfun
    jac = vmap(pullback)(_std_basis(y))

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/traceback_util.py:166 in reraise_with_filtered_traceback
    return fun(*args, **kwargs)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/api.py:1258 in vmap_f
    out_flat = batching.batch(

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/linear_util.py:188 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/tree_util.py:361 in __call__
    return self.fun(*args, **kw)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/api.py:2161 in _vjp_pullback_wrapper
    ans = fun(*args)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/tree_util.py:361 in __call__
    return self.fun(*args, **kw)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:147 in unbound_vjp
    arg_cts = backward_pass(jaxpr, reduce_axes, True, consts, dummy_args, cts)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:254 in backward_pass
    cts_out = get_primitive_transpose(eqn.primitive)(

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:761 in _custom_lin_transpose
    cts_in = bwd(*res, *cts_out)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/custom_derivatives.py:769 in <lambda>
    bwd_ = lambda *args: bwd(*args)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/linear_util.py:188 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/web.py:169 in run_bwd
    jax_sim_adj = sim_data_vjp.make_adjoint_simulation(fwidth=fwidth_adj, run_time=run_time_adj)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/components/data/sim_data.py:175 in make_adjoint_simulation
    for adj_source in mnt_data_vjp.to_adjoint_sources(fwidth=fwidth):

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/components/data/monitor_data.py:83 in to_adjoint_sources
    amps, sel_coords = self.amps.nonzero_val_coords

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/components/base.py:51 in cached_property_getter
    computed_value = prop(self)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/components/data/data_array.py:448 in nonzero_val_coords
    values = np.nan_to_num(self.as_ndarray)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/components/base.py:51 in cached_property_getter
    computed_value = prop(self)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/tidy3d/plugins/adjoint/components/data/data_array.py:131 in as_ndarray
    return np.array(self.values)

  File ~/micromamba/envs/tidy3d/lib/python3.10/site-packages/jax/_src/core.py:611 in __array__
    raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape complex64[1,1].
This BatchTracer with object id 140334367546704 was created on line:
  /home/pr/rare_earth_ions/simulation/problem_jacobian.py:94 (compute_power)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
remypa added the feature label Mar 17, 2024
@tylerflex (Collaborator)

Hi @remypa ,

This isn't something we've intended to support, nor have we tested it before. However, we can put it on our roadmap for future versions. It's possible some of my changes in #1551 may make it work, but it's unlikely.

What I don't understand exactly is why you'd need to do this; perhaps you can explain. My understanding is that you have a set of objective functions f_i, each of which depends on a Tidy3D simulation.

Typically you'd combine all of these into a single objective function, e.g. by summing over the f_i. In that case you could still use value_and_grad. However, it sounds like you'd like to store the individual f_i values and then do some additional processing with them? Can you explain that a bit more?

Ultimately if you are able to combine everything you'd like to do in a single objective function, it should still be possible to use value_and_grad.

Note that you can always use has_aux in the value_and_grad call if you simply need to store these f_i values to process outside of the loop. I'd recommend looking at this tutorial, cells [17]-[18], for an example.
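For reference, a minimal sketch of that pattern (objective, f, m, and params0 below are placeholders, not names from the tutorial):

import jax
import jax.numpy as jnp

def objective(params):
    fs = jnp.array([f(params, i) for i in range(m)])  # individual objective values
    total = jnp.sum(fs)  # the scalar that gets differentiated
    return total, fs  # aux output: the raw f_i values

(total, fs_vals), grads = jax.value_and_grad(objective, has_aux=True)(params0)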

tylerflex self-assigned this Mar 18, 2024
remypa (Author) commented Mar 18, 2024

Hi,

thanks a lot for your reply.

I am working on a minimax-type problem:

$$\min_{x \in \mathbb{R}^{n}} \max(f_1(x), f_2(x), \dots, f_m(x))$$

which I reformulate as

$$\displaylines{\min_{x \in \mathbb{R}^{n},\, t \in \mathbb{R}} t \\ \text{s.t. } t \ge f_k(x) \text{ for } k = 1, 2, \dots, m}$$

as per https://nlopt.readthedocs.io/en/latest/NLopt_Introduction/#equivalent-formulations-of-optimization-problems.

This is why I don't really recombine the f_i into a single objective function.

tylerflex (Collaborator) commented Mar 18, 2024

I see. Yeah, we'll have to work on improving the compatibility between Tidy3D's adjoint plugin and nlopt for these sorts of problems.

In the meantime, I might suggest you use a softmax, i.e. jax.nn.softmax. You can use it to weight your f_i so as to preferentially penalize the largest one, such that your objective function is still differentiable.

Some pseudo-code below, but double check the specifics.

import jax
import jax.numpy as jnp

def objective(x):
    fs = jnp.array([f(x, i) for i in range(m)])  # individual objectives
    weights = jax.nn.softmax(fs)  # larger f_i get exponentially larger weights
    return jnp.sum(weights * fs)  # smooth, differentiable surrogate for max(fs)

My intuition tells me that this should work reasonably well without needing to transform the problem to constraints.
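(A common refinement, purely as a suggestion: divide by a temperature before the softmax, i.e. jax.nn.softmax(fs / T); smaller T weights the maximum more heavily, approaching the hard max as T goes to 0.)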

@tylerflex (Collaborator)

EDIT: forgot the jnp.sum in the return; fixed.

@ianwilliamson

The direct use of jax.jacrev(vector_valued_fn) requires that all operations in vector_valued_fn() have batching rules defined. If you look at the implementation of jax.jacrev, you'll see that it just vmaps over jax.vjp. Generally, this means that one needs to define JAX primitives with batching rules, which is different from the strategy of defining jax.custom_vjp rules that wrap non-JAX code (Tidy3D's approach). There is no way to avoid leaking JAX types into the wrapped code when higher-order JAX transformations are used on a custom_vjp, which is why you see the error in the OP.
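As a self-contained sketch of that mechanism (a toy function, nothing Tidy3D-specific):

import jax
import jax.numpy as jnp

def vector_valued_fn(x):
    return jnp.array([jnp.sum(x ** 2), jnp.prod(x)])

x = jnp.arange(1.0, 4.0)
y, pullback = jax.vjp(vector_valued_fn, x)
# vmap the pullback over the standard basis of the output space; this is
# essentially what jax.jacrev does internally, and it is this vmap that
# pushes BatchTracers into any custom_vjp rule along the way
(jac,) = jax.vmap(pullback)(jnp.eye(y.size))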

You don't need to use jax.jacrev, and it might even be less convenient since it does not return the vector value (only the Jacobian). You can instead manually manage the construction of the constraint vector Jacobian using a Python loop (or whatever batching Tidy3D provides). This would allow you to perform the epigraph minimax style of optimization that has been popularized by Meep.

@yaugenst (Contributor)

Hey @remypa,

To add to the above: since everything in nlopt needs to happen outside of JAX anyway, it is perfectly fine to construct the constraint vector as @ianwilliamson described. In the simplest case (single wavelength, only differentiating w.r.t. a single argument), that would look something like this:

import numpy as np

# obj_fun_vgs is assumed to be a list of gradient functions defined elsewhere,
# each of the form: obj_fun_vg = jax.value_and_grad(objective)

def nlopt_epigraph_constraint(result: np.ndarray, x: np.ndarray, gd: np.ndarray) -> None:
    # x = [t, v]: t is the epigraph variable, v the design parameters
    t, v = x[0], x[1:]

    # evaluate all objectives and collect their gradients
    obj_vec, grad_vec = [], []
    for obj_fun_vg in obj_fun_vgs:
        obj_val, grad_val = obj_fun_vg(v)
        obj_vec.append(obj_val)
        grad_vec.append(grad_val)

    # gradients of the constraints f_k(v) - t <= 0 w.r.t. [t, v]
    if gd.size > 0:
        gd[:, 0] = -1
        gd[:, 1:] = np.asarray(grad_vec)

    result[:] = np.asarray(obj_vec) - t

You can also parallelize the evaluation of the objective functions in that loop with something like async/await, or possibly via Tidy3D's built-in batching, although I don't know.
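As a rough sketch of the thread-based route (whether this plays nicely with Tidy3D's client is worth verifying; obj_fun_vgs and v as above):

from concurrent.futures import ThreadPoolExecutor

# each obj_fun_vg mostly blocks waiting on a remote simulation,
# so a thread pool overlaps those waits
with ThreadPoolExecutor() as pool:
    results = list(pool.map(lambda fn: fn(v), obj_fun_vgs))
obj_vec, grad_vec = zip(*results)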
It might be nice to support this out of the box, but it would really just mean moving the bookkeeping (assembling the constraint vector) somewhere inside Tidy3D's adjoint module. How exactly that needs to be handled depends on the optimization package: for example, how nlopt handles this might differ from scipy or IPOPT, so there is no general solution there.

remypa (Author) commented Mar 19, 2024

Hi,

thanks for all your input and suggestions.

@ianwilliamson, @yaugenst: that is the way I am doing it at the moment. It does work, but the difficulty is that I don't think I can use Tidy3D's batch infrastructure out of the box, hence my initial post.

@tylerflex : I have started looking into softmax. My initial results are promising.

Cheers.
