BF16 matmul slower than F32 matmul on T4 GPU #12429

Open
sagelywizard opened this issue May 13, 2024 · 3 comments

Comments

@sagelywizard

T4 GPU doesn't support BF16 matmul. Because of this, XLA switches BF16 matmul to F32 matmul on T4 (IIUC). That's obviously much slower than native BF16 would be, but it turns out it's actually slower than a plain F32 matmul as well (i.e. BF16 appears to be less than 50% of the speed of F32), so there must be something else going on here. If I understand correctly, "BF16" matmul should be the same performance as F32.

I also filed an issue against JAX, since that's where I discovered this issue. google/jax#21212

As I mentioned in the other issue, you can repro on a T4 Colab with the following code:

import jax
import jax.numpy as jnp
import timeit

def flops_calc(exponent=16, iters=10, dtype=jnp.float16):
  key = jax.random.PRNGKey(0)
  x_i = 2**exponent
  x_j = 4096
  y_j = 4096
  # 2 * M * K * N floating-point ops for an (M, K) x (K, N) matmul.
  flop_count = x_i * x_j * y_j * 2
  x = jax.random.uniform(key, (x_i, x_j), dtype=dtype)
  y = jax.random.uniform(key, (x_j, y_j), dtype=dtype)
  matmul = jax.jit(lambda a, b: a @ b)
  # Warm-up call so compilation time isn't included in the measurement.
  matmul(x, y).block_until_ready()
  seconds_per_iter = timeit.timeit(lambda: matmul(x, y).block_until_ready(), number=iters) / iters
  flops = flop_count / seconds_per_iter
  return flop_count, flops

def flops_to_tflops(flops):
  return flops / 1e12

for dtype in [jnp.bfloat16, jnp.float16, jnp.float32]:
  print(dtype)
  for i in range(16):
    op_count, flops = flops_calc(exponent=i, dtype=dtype)
    print(f'Total TFLOP Count: {op_count / 1e12:.5f} | TFLOPS: {flops_to_tflops(flops):.2f}')
  print()
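
For reference, one way to check what XLA actually compiles the BF16 dot into on a given GPU is to dump the post-optimization HLO via JAX's ahead-of-time lowering API. This is a debugging sketch for illustration only, not part of the original repro, and the exact availability of as_text() depends on your JAX version:

import jax
import jax.numpy as jnp

x = jnp.ones((1024, 4096), dtype=jnp.bfloat16)
y = jnp.ones((4096, 4096), dtype=jnp.bfloat16)
matmul = jax.jit(lambda a, b: a @ b)

# Post-optimization HLO; on a T4 this should show converts around an F32 GEMM
# rather than a native BF16 GEMM.
print(matmul.lower(x, y).compile().as_text())
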
@cheshire
Member

T4 GPU doesn't support BF16 matmul

It actually does, but it doesn't use Tensor Cores and is incredibly slow

XLA switches BF16 matmul to F32 matmul on T4

This is a fairly recent change I made; you could try to find the commit that introduced it. Without that change, matmuls are >4x slower from what I recall (depending on the shape)

If I understand correctly, "BF16" matmul should be the same performance as F32.

Why would it? T4 has neither vector nor Tensor Core support for BF16, so it has to emulate it, slowly.

Or do you mean on T4? On T4, you can look at the GPU profile.

Here the problem is that we use Triton for fusions, and Triton recently dropped support for pre-Ampere GPUs (or at least they aren't officially supported). Without fusions, we need to run an extra kernel to cast from BF16 to F32, which can be as expensive as the matmul itself.
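
To make that concrete, here is a rough sketch in plain JAX of what the unfused path amounts to; this is only an illustration of the extra work involved, not XLA's actual lowering:

import jax.numpy as jnp

def unfused_bf16_matmul(a_bf16, b_bf16):
  # Extra kernel(s): materialize F32 copies of the BF16 operands in memory.
  a_f32 = a_bf16.astype(jnp.float32)
  b_f32 = b_bf16.astype(jnp.float32)
  # The GEMM itself then runs in F32, the dtype the T4 supports natively.
  c_f32 = a_f32 @ b_f32
  # Cast the result back to the requested BF16 output type.
  return c_f32.astype(jnp.bfloat16)

The extra reads and writes for those casts are what can make this slower than an F32 matmul of the same shape.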

@sagelywizard
Author

Why would it?

Sorry, I misspoke a bit. I meant that I'd expect the emulation on T4 to be in the ballpark of (or at least not slower than) F32. But it sounds like it could be slower than F32 because of the extra cast?
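
As a rough sanity check of the extra-cast theory (an illustrative sketch, not something measured in this thread), one could time just the BF16-to-F32 conversions for the benchmark's shapes and compare that to the F32 matmul time:

import jax
import jax.numpy as jnp
import timeit

key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (2**15, 4096), dtype=jnp.bfloat16)
y = jax.random.uniform(key, (4096, 4096), dtype=jnp.bfloat16)

# Time only the casts that an unfused BF16 matmul would have to run.
cast = jax.jit(lambda a, b: (a.astype(jnp.float32), b.astype(jnp.float32)))
for out in cast(x, y):  # warm-up / compile
  out.block_until_ready()
iters = 10
seconds = timeit.timeit(
    lambda: [o.block_until_ready() for o in cast(x, y)], number=iters) / iters
print(f'cast-only time per iteration: {seconds * 1e3:.2f} ms')

If the cast time is comparable to the F32 matmul time at the same shape, that would account for BF16 coming in at well under half of F32 throughput.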

@cheshire
Member

Yes. Since we support CUTLASS fusions, I might look into supporting that fusion (cast into matmul) via CUTLASS.
