Convert a BrainPy model to process batched input by jax.vmap #608

Open
CloudyDory opened this issue Jan 31, 2024 · 2 comments

@CloudyDory
Contributor

CloudyDory commented Jan 31, 2024

For a model written to process single input data, is it possible to convert the model to process batched input data simply by using jax.vmap? Or do we have to re-write the model to process batched data?

The code section looks like this:

import functools

import jax
import numpy as np
import brainpy as bp
import brainpy.math as bm

# define the optimizer we need
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())

def step_run(i, x_single):
    '''
    Inputs:
        x_single: [height, width]
    '''
    x = bm.where(bm.logical_and(cfg['stim_start_timepoint']<=i, i<cfg['stim_end_timepoint']), x_single, blank_img)
    out = model.step_run(i, x)  # [n_neuron]
    return out

def loss_fun(x_single, y_single):
    '''
    Inputs:
        x_single: [height, width]
        y_single: [1]
    '''
    model.reset_state() 
    indices = np.arange(cfg['total_timepoint'])  # sequence length
    spike_out = bm.for_loop(functools.partial(step_run, x_single=x_single), indices)  # [length, n_neuron]
    frate_out = bm.sum(spike_out, axis=0) + 1.0e-6  # [n_neuron]
    
    predicts = bm.log(frate_out / bm.sum(frate_out)).unsqueeze(0)  # log-probabilities, [batch=1, n_neuron]
    loss = bp.losses.nll_loss(-predicts, y_single)  # scalar; we need to manually add a negative sign because BrainPy does not do so
    acc = bm.mean(predicts.argmax(-1) == y_single)  # scalar
    return loss, acc

grad_f = jax.vmap(bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True))

@bm.jit
def train(x_batch, y_batch):
    '''
    Inputs:
        x_batch: [batch, height, width]
        y_batch: [batch, 1]
    '''
    train_vars = model.train_vars().unique()

    grads, losses, accs = grad_f(x_batch, y_batch)  # PyTree of gradients, [batch], [batch]
    grads_mean = jax.tree_map(lambda x: bm.sum(x, axis=0), grads)

    loss = losses.mean()  # scalar
    acc = accs.mean()     # scalar
    opt.update(grads_mean)
    
    return loss, acc

It currently raises the following error:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[50000] wrapped in a BatchTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.

I found a previous issue (#206) mentioning this. Is it still not possible to use jax.vmap with brainpy models?

@chaoming0625
Collaborator

Thanks for opening this great question. Actually, the object-oriented style in BrainPy does not support general mapping transformations such as vmap and pmap. But we can easily customize our own mapping for a specific problem. Here I will give you an example.

@chaoming0625
Collaborator

chaoming0625 commented Feb 1, 2024

The key of BrainPy's Variable system is to find out all variables used in an object and then transform that object into a function, so that it can be compiled by JAX's functional transformations. Existing BrainPy transformations like brainpy.math.jit and brainpy.math.scan already hide these steps. However, for a new transformation, users can also follow the same two steps themselves, as in the sketch below.
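
To make these two steps concrete, here is a minimal sketch of roughly what brainpy.math.jit hides (not from this thread; model.update and x are placeholder names, and the code is only an illustration of the pattern, not tested code):

import jax
import brainpy.math as bm

# Step 1: find all variables touched by the call (only shapes are evaluated).
variables, _ = bm.eval_shape(model.update, x)
ori_data = variables.dict_data()

# Step 2: wrap the object-oriented call into a pure function of
# (variable data, input), which plain jax.jit can compile.
def pure_fn(var_data, inp):
  for key in var_data:
    variables[key].value = var_data[key]   # push the data into the model's variables
  out = model.update(inp)                  # run the object-oriented code
  new_data = {k: v.value for k, v in variables.items()}
  return out, new_data                     # explicitly return the updated states

out, new_data = jax.jit(pure_fn)(ori_data, x)

# restore the original data so the model is not left holding traced values
for key in ori_data:
  variables[key].value = ori_data[key]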

In your case, you want to vmap the gradient function to get batched gradients. So the weights should not be batched, while all other states/variables should be batched, and the outputs should also be batched. Therefore, we can customize this transformation as:

import jax
from functools import wraps

import brainpy.math as bm


def vmap_grad_fun(f, *inputs):
  # Step 1: find out all variables used in the function #
  # ---------------------------------------------------- #

  # shape evaluation without spending any actual FLOP computation
  all_vars, _ = bm.eval_shape(f, *inputs)

  # separate the variables into two groups: trainable weights and other states
  weights, states = all_vars.separate_by_instance(bm.TrainVar)

  # Step 2: transform the object into a function that is compatible with jax.vmap #
  # ------------------------------------------------------------------------------ #

  @wraps(f)
  def new_fun(ws, sts, ins):
    # A. assign the shared weights and the per-batch states to the model
    for key in ws:  weights[key].value = ws[key]
    for key in sts: states[key].value = sts[key]

    # B. run the function
    outputs = f(*ins)

    # C. return the outputs of this batch element
    return outputs

  ori_weights, ori_states = weights.dict_data(), states.dict_data()

  # replicate the states along a new leading batch dimension
  batch_size = inputs[0].shape[0]
  batched_states = jax.tree_map(lambda x: bm.repeat(bm.expand_dims(x, 0), batch_size, axis=0), ori_states)

  # vmap over the batched states and inputs, keeping the weights shared
  batched_outs = jax.vmap(new_fun, in_axes=(None, 0, 0), out_axes=0)(ori_weights, batched_states, inputs)
  del batched_states

  # restore the original weights and states
  for key in ori_weights: weights[key].value = ori_weights[key]
  for key in ori_states: states[key].value = ori_states[key]

  # Step 3: return the batched outputs
  return batched_outs
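
For instance, with the snippet in your question it could be used roughly like this (an untested sketch; loss_fun, x_batch, y_batch, model, and opt are the names from your code):

# Untested sketch: build the per-sample gradient function, then map it over
# the leading batch axis of (x_batch, y_batch) with the helper above.
grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(),
                 has_aux=True, return_value=True)

grads, losses, accs = vmap_grad_fun(grad_f, x_batch, y_batch)
grads_mean = jax.tree_map(lambda g: bm.sum(g, axis=0), grads)  # aggregate gradients over the batch
opt.update(grads_mean)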

I hope this example can help you achieve the desired transformation.
