
Half precision (float16 or bfloat16) support #539

Open
CloudyDory opened this issue Nov 8, 2023 · 10 comments
Labels
brainpy.dyn (issue belongs to brainpy.dyn module) · brainpy.math (issue belongs to brainpy.math module) · enhancement (New feature or request)

Comments

CloudyDory (Contributor) commented Nov 8, 2023

Does BrainPy fully support half-precision floating-point numbers? I have tried changing some of my own BrainPy code from brainpy.math.float32 to brainpy.math.float16 or brainpy.math.bfloat16 (by explicitly setting the dtype of all variables and using a debugger to make sure they are not promoted to float32), but the GPU memory consumption and running speed seem almost the same as with float32.
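For illustration, the kind of explicit dtype setting described above might look roughly like the following sketch (the variable names are hypothetical; brainpy.math mirrors the jax.numpy array API, so zeros/asarray accept a dtype argument):

import brainpy.math as bm

# Hypothetical example of setting the dtype of every variable explicitly,
# instead of relying on the global default float type.
V = bm.Variable(bm.zeros(1000, dtype=bm.float16))   # e.g. membrane potentials
w = bm.random.normal(size=(1000, 1000))             # created as float32 by default
w16 = bm.asarray(w, dtype=bm.float16)               # explicit cast to half precision

print(V.dtype, w16.dtype)  # float16 float16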

chaoming0625 (Collaborator) commented Nov 8, 2023

Great! Supporting this requires explicitly casting all parameters to brainpy.math.float_. For example, for an HH neuron model, its parameter gNa should be reinterpreted as gNa = bm.asarray(gNa, bm.float_). Ideally, users could call brainpy.math.set(float_=bm.float16), and all variables would then run with float16.
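As a rough illustration of that casting pattern (the parameter values below are placeholders, not BrainPy's actual HH defaults):

import brainpy.math as bm

bm.set(float_=bm.float16)  # ideally this alone switches the default float type

# Inside a model's __init__, each parameter is cast to the global float type,
# e.g. for an HH-style neuron (placeholder values):
gNa = bm.asarray(120.0, dtype=bm.float_)  # maximal sodium conductance
gK = bm.asarray(36.0, dtype=bm.float_)    # maximal potassium conductance
gL = bm.asarray(0.3, dtype=bm.float_)     # leak conductance

print(gNa.dtype)  # float16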

chaoming0625 added the enhancement, brainpy.math, and brainpy.dyn labels on Nov 8, 2023
chaoming0625 (Collaborator) commented

One more thing to take care of: the coefficients of the Runge–Kutta methods should also be cast to the brainpy.math.float_ type.

CloudyDory (Contributor, Author) commented

> One more thing to take care of: the coefficients of the Runge–Kutta methods should also be cast to the brainpy.math.float_ type.

Could you let me know how to cast the Runge–Kutta coefficients to brainpy.math.float_? It seems that the coefficients are automatically generated.

chaoming0625 (Collaborator) commented

Yes, the changes need to be made inside the BrainPy framework itself. Note that dt should also be cast in the integrators.
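To make the idea concrete, here is a minimal, framework-independent sketch of a midpoint (RK2) step whose coefficient and dt are cast to bm.float_; the real fix belongs in BrainPy's auto-generated integrator code, so this function is only illustrative:

import brainpy.math as bm

def rk2_step(f, x, t, dt):
  # Cast dt and the Runge-Kutta coefficient to the global float type so
  # they do not silently promote half-precision state back to float32.
  dt = bm.asarray(dt, dtype=bm.float_)
  half = bm.asarray(0.5, dtype=bm.float_)
  k1 = f(x, t)
  k2 = f(x + half * dt * k1, t + half * dt)
  return x + dt * k2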

CloudyDory (Contributor, Author) commented

Update: I think the GPU memory consumption is mostly determined by JAX, which preallocates 75% of the total GPU memory by default. This may be why I don't see a reduction in memory consumption after switching to FP16.

chaoming0625 (Collaborator) commented

The preallocation can be disabled by calling brainpy.math.disable_gpu_memory_preallocation().
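For example (the environment-variable alternative is plain JAX behavior, not BrainPy-specific):

import brainpy.math as bm

# Call before any GPU computation so JAX allocates memory on demand
# instead of reserving 75% of the device memory up front.
bm.disable_gpu_memory_preallocation()

# Equivalent plain-JAX alternative, set before importing jax/brainpy:
#   export XLA_PYTHON_CLIENT_PREALLOCATE=false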

CloudyDory (Contributor, Author) commented

Hi, when running bm.set(float_=bm.bfloat16), I get a NotImplementedError. Is bfloat16 currently not supported in BrainPy?

chaoming0625 (Collaborator) commented

It is supported, but the set operation does not recognize it. Maybe we should customize the set() function.

CloudyDory (Contributor, Author) commented

I guess we should just add one more condition in the set_float function in brainpy/_src/math/environment.py?

def set_float(dtype: type):
  """Set global default float type.

  Parameters
  ----------
  dtype: type
    The float type.
  """
  if dtype in [jnp.float16, 'float16', 'f16']:
    defaults.__dict__['float_'] = jnp.float16
    defaults.__dict__['ti_float'] = ti.float16
  elif dtype in [jnp.float32, 'float32', 'f32']:
    defaults.__dict__['float_'] = jnp.float32
    defaults.__dict__['ti_float'] = ti.float32
  elif dtype in [jnp.float64, 'float64', 'f64']:
    defaults.__dict__['float_'] = jnp.float64
    defaults.__dict__['ti_float'] = ti.float64
  else:
    raise NotImplementedError
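
A minimal sketch of that extra branch, inserted before the final else, might look like the following. jnp.bfloat16 is a real JAX dtype, but Taichi has no bfloat16 type, so the fallback chosen for ti_float below is an assumption that would need to be decided:

  elif dtype in [jnp.bfloat16, 'bfloat16', 'bf16']:
    defaults.__dict__['float_'] = jnp.bfloat16
    # Assumption: fall back to float16 for Taichi, which lacks bfloat16.
    defaults.__dict__['ti_float'] = ti.float16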

chaoming0625 (Collaborator) commented

Yes!
