Multi-GPU simulation and training of one dynamical system in BrainPy #641

Open
Dr-Chen-Xiaoyu opened this issue Mar 2, 2024 · 15 comments
Labels
enhancement New feature or request

Comments

@Dr-Chen-Xiaoyu

Hi, Chaoming:

I am trying to do simulation and training of a dynamical system (a custom RNN built on BrainPy, https://github.com/Dr-Chen-Xiaoyu/DecoModel) with a very large dimension and many time steps. The memory usage exceeds what a single GPU device can hold.

I believe this could be solved by running BrainPy on multiple GPU devices with its own sharding method, just like JAX's sharding or PyTorch's torch.nn.DataParallel. A simplified RNN training example is provided below; to reproduce the problem, increase the RNN dimension to something very large (maybe >1000) as well as the input/output tensors (maybe >1000^3). Maybe you could modify this code with BrainPy's sharding and make it part of BrainPy's tutorials, if this is a general demand from users.

best,
Xiaoyu

The example code:

# %%
import os,jax
import numpy as np
import matplotlib.pyplot as plt

import brainpy as bp
import brainpy.math as bm
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1" # specify which GPU(s) to be used
bm.disable_gpu_memory_preallocation()
bm.set_platform('gpu')
bm.set_mode(bm.training_mode)

print('bp version:', bp.__version__)
print(jax.local_devices())
#bp version: 2.4.6.post5
#[cuda(id=0), cuda(id=1)]

# %%
class RNN(bp.DynamicalSystemNS):
    def __init__(self, num_in, num_hid, num_out, batch_size=1):
        super(RNN, self).__init__()

        bp.check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode))

        # define parameters
        self.num_in  = num_in
        self.num_hid = num_hid
        self.num_out = num_out

        # define variables
        self.state = bm.Variable(bm.zeros((batch_size, num_hid)), batch_axis=0)

        # define weights
        self.win  = bm.TrainVar(bm.random.normal(0., 1., size=(num_in,  num_hid)))
        self.wrec = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_hid)))
        self.wout = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_out)))

    def reset_state(self, batch_size):  # this function defines how to reset the model's states
        self.state.value = bm.zeros((batch_size, self.num_hid))

    def update(self, x):  # this function defines how the model updates its state and produces its output
        self.state.value = bm.tanh( bm.matmul(x, self.win) + bm.matmul(self.state, self.wrec) )
        return bm.matmul(self.state, self.wout)

# initialize model
bm.random.seed(123)
dim_in =1
dim_hid=10
dim_out=1
batch_size=1
model = RNN(dim_in, dim_hid, dim_out , batch_size)

# %%
# generate some data
Nsample = 500
X_train = bm.random.normal(0.,1., size=(batch_size ,Nsample,dim_in)) #(Batch,Time,dim)
Y_train = bm.random.normal(10.,1., size=(batch_size, Nsample,dim_out))

def plot_model_predict(model,X_train,Y_train):
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    Y_model = runner.run(inputs=X_train)

    plt.plot(X_train[0,:,:])
    plt.plot(Y_train[0,:,:])
    plt.plot(Y_model[0,:,:])
    plt.show()
plot_model_predict(model,X_train,Y_train)

# %%
# training
def loss_fun(inputs, targets):
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    predicts = runner.predict(inputs)
    loss = bp.losses.mean_squared_error(predicts, targets)
    return loss

grad_fun = bm.grad(loss_fun,grad_vars=model.train_vars().unique(),return_value=True)

opt = bp.optim.Adam(lr=1e-1, train_vars=model.train_vars().unique())

@bm.jit
def train(xs, ys):
    grads, loss = grad_fun(xs, ys)
    opt.update(grads)
    return loss

losses=[]
for _ in range(1000):
    losses.append(train(X_train,Y_train))
    
plt.plot(losses);plt.show()
plot_model_predict(model,X_train,Y_train)
@Dr-Chen-Xiaoyu added the enhancement (New feature or request) label on Mar 2, 2024
@Dr-Chen-Xiaoyu
Author

I think I might have found a way to shard a bm.array, based on JAX's tutorial https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html :


# %%
import jax
import jax.numpy as jnp

import os
import numpy as np


os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1" # specify which GPU(s) to be used
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
import brainpy as bp
import brainpy.math as bm
bm.disable_gpu_memory_preallocation()
bm.set_platform('gpu')
print('bp version:', bp.__version__)

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PositionalSharding
from jax.sharding import PartitionSpec as P

# %%
def get_sharding_details(sharded_data):
    # We can get detailed information for each shard
    print("="*75)
    for i, shard in enumerate(sharded_data.global_shards):
        print(f"Shard no: {i:>5}")
        print(f"Device: {str(shard.device):>32}")
        print(f"Data shape: {str(shard.data.shape):>8}")
        print(f"Data slices: {str(shard.index):>22}")
        print("="*75)

# %%
devices = mesh_utils.create_device_mesh((len(jax.local_devices()),))
print(f"Device Array: {devices}")

# Create a mesh from the device array
mesh = Mesh(devices, axis_names=("ax",))

# Define sharding with a partition spec
sharding = NamedSharding(mesh, P("ax"))

print(mesh)

# %%
a = jnp.ones((1000,1000,3))
get_sharding_details(a)

print("\nafter sharding:\n")

# Shard the data
b = jax.device_put(a, sharding)
get_sharding_details(b)

# %%
c = bm.ones((1000,1000,3))
get_sharding_details(c.value)

print("\nafter sharding:\n")

# Shard the data
d = bm.sharding.partition_by_sharding(c, sharding)
get_sharding_details(d.value)

Maybe we could just shard the input/output bm.array tensors along the batch axis, and then let the computation run automatically on the multiple GPUs?
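For example, something like this sketch, reusing the mesh and helper defined above (whether the batch axis is really the right axis to split is just my guess):

# partition only axis 0 (the batch axis) across the 'ax' mesh axis
batch_sharding = NamedSharding(mesh, P("ax", None, None))
X = bm.random.normal(0., 1., size=(2, 500, 1))       # (Batch, Time, dim); batch divisible by #devices
X_sharded = jax.device_put(X.value, batch_sharding)  # each GPU would hold one batch slice
get_sharding_details(X_sharded)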
Just some thought 😊

@Dr-Chen-Xiaoyu
Author

The printed output looks something like this for the before- and after-sharding arrays:

===========================================================================
Shard no:     0
Device:                           cuda:0
Data shape: (1000, 1000, 3)
Data slices: (slice(None, None, None), slice(None, None, None), slice(None, None, None))
===========================================================================

after sharding:

===========================================================================
Shard no:     0
Device:                           cuda:0
Data shape: (500, 1000, 3)
Data slices: (slice(0, 500, None), slice(None, None, None), slice(None, None, None))
===========================================================================
Shard no:     1
Device:                           cuda:1
Data shape: (500, 1000, 3)
Data slices: (slice(500, 1000, None), slice(None, None, None), slice(None, None, None))
===========================================================================

@chaoming0625
Collaborator

Thanks for the question. Sorry for the slow response. I will check it later.

@Dr-Chen-Xiaoyu
Author

Hi, chaoming @chaoming0625

Maybe this issue is a bit hard, with too much engineering work needed to achieve it. 🫡

I just have an idea for a quick and cheap solution to this issue. As in #663, if any built-in or customized BrainPy dynamical system class could be automatically transformed into a Flax RNN cell using bp.dnn.ToFlaxRNNCell(), then we could do multi-GPU parallel training with Flax (https://flax.readthedocs.io/en/latest/guides/parallel_training/index.html). 🤖
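For illustration only, here is a rough, generic sketch of the data-parallel pattern with plain jax.pmap (loss_fn, params, and the data shapes are placeholders, not the BrainPy or Flax API; in the real case the loss would run the converted RNN cell):

import functools
import jax
import jax.numpy as jnp

def loss_fn(params, xs, ys):
    # placeholder loss; stands in for running the Flax-wrapped RNN cell
    preds = xs @ params['w']
    return jnp.mean((preds - ys) ** 2)

@functools.partial(jax.pmap, axis_name='batch')
def train_step(params, xs, ys):
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    grads = jax.lax.pmean(grads, axis_name='batch')  # average gradients across devices
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-1 * g, params, grads)
    return params, loss

n_dev = jax.local_device_count()
params = jax.device_put_replicated({'w': jnp.zeros((4, 1))}, jax.local_devices())
xs = jnp.ones((n_dev, 8, 4))   # leading axis = device axis, 8 samples per device
ys = jnp.ones((n_dev, 8, 1))
params, loss = train_step(params, xs, ys)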

best,
Xiaoyu Chen

@chaoming0625
Collaborator

Yes, the idea is simple. I will give you a solution soon.

@chaoming0625
Collaborator

Here is my example of using multiple GPUs. I have marked the key code with the comment [KEY].

import os
import jax

import brainpy as bp
import brainpy.math as bm

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # specify which GPU(s) to be used
bm.disable_gpu_memory_preallocation()
bm.set_platform('gpu')
bm.set_mode(bm.training_mode)

print('bp version:', bp.__version__)
print(jax.local_devices())


# bp version: 2.4.6.post5
# [cuda(id=0), cuda(id=1)]

# %%
class RNN(bp.DynamicalSystemNS):
  def __init__(self, num_in, num_hid, num_out, batch_size=1):
    super(RNN, self).__init__()

    bp.check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode))

    # define parameters
    self.num_in = num_in
    self.num_hid = num_hid
    self.num_out = num_out

    # define variables [KEY]
    self.state = bp.init.variable(bm.zeros, num_hid, batch_size, axis_names=['hidden'])
    # self.state = bm.Variable(bm.zeros((batch_size, num_hid)), batch_axis=0)

    # define weights [KEY]
    self.win = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_in, num_hid), axis_names=[None, 'hidden']))
    self.wrec = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_hid, num_hid), axis_names=[None, 'hidden']))
    self.wout = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_hid, num_out), axis_names=['hidden', None]))
    # self.win = bm.TrainVar(bm.random.normal(0., 1., size=(num_in, num_hid)))
    # self.wrec = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_hid)))
    # self.wout = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_out)))

  def reset_state(self, batch_size):  # this function defines how to reset the model's states
    self.state.value = bp.init.variable_(bm.zeros, (self.num_hid,), batch_size, axis_names=['hidden'])

  def update(self, x):  # this function defines how the model updates its state and produces its output
    self.state.value = bm.tanh(bm.matmul(x, self.win) + bm.matmul(self.state, self.wrec))
    return bm.matmul(self.state, self.wout)


with bm.sharding.device_mesh(jax.devices(), ['hidden']):  # [KEY]
  # initialize model
  bm.random.seed(123)
  dim_in = 1
  dim_hid = 10
  dim_out = 1
  batch_size = 1
  model = RNN(dim_in, dim_hid, dim_out, batch_size)

  # %%
  # generate some data
  Nsample = 500
  X_train = bm.random.normal(0., 1., size=(batch_size, Nsample, dim_in))  # (Batch,Time,dim)
  Y_train = bm.random.normal(10., 1., size=(batch_size, Nsample, dim_out))


  # training
  def loss_fun(inputs, targets):
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    predicts = runner.predict(inputs)
    loss = bp.losses.mean_squared_error(predicts, targets)
    return loss


  grad_fun = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), return_value=True)

  opt = bp.optim.Adam(lr=1e-1, train_vars=model.train_vars().unique())


  @bm.jit
  def train(xs, ys):
    grads, loss = grad_fun(xs, ys)
    opt.update(grads)
    return loss

  losses = []
  for _ in range(1000):
    losses.append(train(X_train, Y_train))

@chaoming0625
Collaborator

chaoming0625 commented Apr 14, 2024

The concept is very simple.

  1. Initialize a context manager to set up a device mesh. Here,

with bm.sharding.device_mesh(devices, ['hidden']):
   ...

means that the hidden dimension will be partitioned across the given devices.

Note that the device array should have the same number of dimensions as the list of axis names. For example, if you want to partition the model onto a two-dimensional device mesh by input and hidden, we should set up the context as:

with bm.sharding.device_mesh(np.asarray(jax.devices()).reshape(2, 2), ['input', 'hidden']):
   ...

  2. Initialize the weight variables using brainpy.init.variable_(..., axis_names=['input', 'hidden']). The data will be automatically partitioned across the devices if a given axis name matches a device mesh axis.

  3. Use brainpy.math.jit. This is the key to the parallelization. All functions should have a jit decorator; otherwise, the model will not be parallelized according to these settings. (A small snippet for inspecting the resulting layout is sketched below.)
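If helpful, here is a small sketch (not part of the example above) for checking the resulting layout after running the jitted train at least once under the device mesh, using the standard attributes of jax arrays:

print(model.state.value.sharding)   # the sharding spec attached to the hidden state
print(model.wrec.value.sharding)    # the recurrent weight should be split along 'hidden'
for shard in model.wrec.value.global_shards:
    print(shard.device, shard.data.shape)  # one tile of the weight per device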

@chaoming0625
Collaborator

Please tell me whether the above code works.

Please also see an example of TPU multi-device partitioning for the COBA-HH network model.

@chaoming0625
Collaborator

chaoming0625 commented Apr 14, 2024

By the way, I apologize for the very late response!

@Dr-Chen-Xiaoyu
Author

It works! Thanks so much!🫰

Without sharding:
[screenshot: GPU memory usage on a single card]

After sharding, the memory is split across the two GPU cards and training is about 2x faster 🫡:
[screenshot: GPU memory usage spread over both cards]

@chaoming0625
Collaborator

Thanks for the feedback!

@Dr-Chen-Xiaoyu
Author

One more question about the details: it seems that you partition the model (the hidden states of this RNN) across the two GPUs. Why not partition along the batch axis? That seems more natural for users.

@chaoming0625
Collaborator

This is a good idea. However, if the batch size is what hinders training the model on one GPU, we can simply decrease the batch size rather than partition it across multiple devices. A more difficult situation is when the model itself is too big to fit on one device. In such cases, we can partition the model across multiple devices, for example when simulating a very large-scale SNN model (which usually has no batch dimension).

@chaoming0625
Collaborator

Partitioning the hidden states and their interaction matrix is a simple model-parallelization method.
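For intuition, here is a minimal pure-JAX sketch of this kind of model parallelism (illustrative names and sizes, not BrainPy's internal implementation): the hidden state and the hidden-by-hidden weight are sharded along a 'hidden' mesh axis, and jit inserts the needed cross-device communication for the matmul.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# one mesh axis named 'hidden' spanning all local devices
mesh = Mesh(mesh_utils.create_device_mesh((jax.local_device_count(),)), ('hidden',))

# split the hidden state and the hidden-by-hidden interaction matrix along 'hidden'
W = jax.device_put(jnp.ones((512, 512)), NamedSharding(mesh, P(None, 'hidden')))
h = jax.device_put(jnp.ones((1, 512)), NamedSharding(mesh, P(None, 'hidden')))

step = jax.jit(lambda h, W: jnp.tanh(h @ W))  # XLA handles the cross-device contraction
print(step(h, W).sharding)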

@Dr-Chen-Xiaoyu
Author

Okay, I see.

By the way, I found that in the model definition, changing only the one line for the model's state variable is enough for parallelization. There is no need to change the weight TrainVars with the axis_names=['input', 'hidden'] annotations.

# define variables
self.state = bp.init.variable(bm.zeros, batch_size, num_hid,  axis_names=['hidden'], batch_axis_name=['batch'])  # <<< the key change

# define weights (no need to change these)
self.win  = bm.TrainVar(bm.random.normal(0., 1., size=(num_in,  num_hid)))
self.wrec = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_hid)))
self.wout = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_out)))

Thanks again for the help👍👍👍
