Question: I can't release the memory in gpu during execute nt.empirical_ntk_fn #146

Open
kkeevin123456 opened this issue Mar 25, 2022 · 4 comments
Labels: question (Further information is requested)

Comments

@kkeevin123456

I have followed the solutions in RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm #8506 and High GPU memory during empirical NTK calculation #100, but they don't work for me.

This code can reproduce the problem:

import numpy as np
import cupy as cp
import torch
import torchvision.datasets as datasets
import torch.nn.functional as F

import jax
from jax import random
import jax.numpy as jnp
from jax.example_libraries import optimizers
from jax import jit, grad, vmap, pmap
import functools
import neural_tangents as nt
from neural_tangents import stax

from tqdm import tqdm
import gc

# Note: CuPy's memory pools only track CuPy's own allocations, not JAX/XLA's.
mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()

# Limit XLA's GPU memory preallocation (read when the JAX backend initializes)
%env XLA_PYTHON_CLIENT_MEM_FRACTION=0.8
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
class_5_idx = np.where(mnist_trainset.targets.numpy() == 5)
class_3_idx = np.where(mnist_trainset.targets.numpy() == 3)
class_5 = mnist_trainset.data[class_5_idx]
class_3 = mnist_trainset.data[class_3_idx]
M = 200
W = H = 28
C = 1
P = 200
eta = 0.1

inputs = np.vstack((class_5[:P//2], class_3[:P//2]))
idx = np.arange(P)
np.random.shuffle(idx)
inputs = inputs[idx].reshape(-1, 28, 28, 1).astype('float32')

ys = np.vstack((np.ones((P//2, 1)), np.zeros((P//2, 1))))  # labels, shape (P, 1)
ys = ys[idx].astype('float32')

# kernel for all layers
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(10, (5, 5), (1, 1), 'SAME'), stax.Relu(),  # Conv(out_chan, filter_shape, strides, padding)
    stax.Flatten(),
    stax.Dense(M), stax.Relu(),
    stax.Dense(1)
)

# kernel for first layer
init_fn_a, apply_fn_a, kernel_fn_a = stax.serial(
    stax.Conv(10, (5, 5), (1, 1), 'SAME'), stax.Relu(),
    stax.Flatten()
)

shape, params = init_fn(random.PRNGKey(np.random.randint(1e6)), inputs.shape)
eNTK = nt.empirical_ntk_fn(apply_fn, vmap_axes=0, trace_axes=(), implementation=2)

opt_init, opt_update, get_params = optimizers.sgd(eta)
opt_state = opt_init(params)

# jit-compiled empirical NTK of the first conv block
eNTK_a = jit(nt.empirical_ntk_fn(apply_fn_a, vmap_axes=0, trace_axes=(), implementation=2))

loss = jit(lambda params, x, y: 0.5 * jnp.mean((apply_fn(params, x) - y) ** 2))
grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))
nsteps = 1000

K_a = np.zeros((P, P))

for i in tqdm(range(nsteps)):
    opt_state = opt_update(i, grad_loss(opt_state, inputs, ys), opt_state)
    
    for j in range(P//2):
        for k in range(P//2):
            params_a = get_params(opt_state)[:3]  # parameters of the conv block
            a = np.sum(
                eNTK_a(inputs[j:(j+1)*2], inputs[k:(k+1)*2], params_a),
                axis=(2, 3))
            K_a[j:(j+1)*2, k:(k+1)*2] = a
            print(mempool.used_bytes())              
            print(mempool.total_bytes())  
            print(pinned_mempool.n_free_blocks())
            a = cp.array(a)
            
            print(mempool.used_bytes())              
            print(mempool.total_bytes())  
            print(pinned_mempool.n_free_blocks())
            del a
            
            print(mempool.used_bytes())              
            print(mempool.total_bytes())
            print(pinned_mempool.n_free_blocks())
            
            mempool.free_all_blocks()
            pinned_mempool.free_all_blocks()
            gc.collect()
        break
    break

After it runs for several iterations, the error happens:

[screenshot: out-of-memory error traceback]

@romanngg
Contributor

romanngg commented Mar 25, 2022

IIUC, for every (j, k) your batch size is (j + 2, k + 2), so it grows every step; hence there may not be a memory leak, but rather each step performs a larger and larger computation. Could you double-check that inputs[j:(j+1)*2], inputs[k:(k+1)*2] are the indices you want to compute?

On a separate note, for CNNs you may want to try implementation=1 - see https://openreview.net/pdf?id=ym68T6OoO6L / https://www.youtube.com/watch?v=8MWOhYg89fY&t=10984s for details on why this could be better (eventually, implementation=3 could also help; we'll add it to NT soon).
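
For concreteness, a minimal sketch combining both suggestions, assuming contiguous 2-sample tiles were intended (the block size B and the params_a name are illustrative, not from the original code):

B = 2  # fixed block size: slices no longer grow with j, k
eNTK_a_fixed = jit(nt.empirical_ntk_fn(
    apply_fn_a, vmap_axes=0, trace_axes=(),
    implementation=1))  # implementation=1, as suggested above for CNNs

params_a = get_params(opt_state)[:3]  # parameters of the first conv block
for j in range(P // B):
    for k in range(P // B):
        x1 = inputs[B * j:B * (j + 1)]  # always exactly B rows
        x2 = inputs[B * k:B * (k + 1)]
        block = np.sum(eNTK_a_fixed(x1, x2, params_a), axis=(2, 3))
        K_a[B * j:B * (j + 1), B * k:B * (k + 1)] = block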

Lmk if this helps!

@romanngg romanngg added the question Further information is requested label Mar 25, 2022
@kkeevin123456 kkeevin123456 changed the title Question: I can release the memory in gpu during execute nt.empirical_ntk_fn Question: I can't release the memory in gpu during execute nt.empirical_ntk_fn Mar 25, 2022
@kkeevin123456
Author

Thanks!!
It works!!

@kkeevin123456
Author

Hi, there is another question.

If I want to parallelize the above code across multiple GPUs, how can I do it?

[screenshot]

It seems to only work on a single GPU.

@romanngg
Contributor

I suggest the nt.batch decorator (https://neural-tangents.readthedocs.io/en/latest/batching.html): it lets you compute the kernel in batches, utilizing all devices. Note that the batch size times the number of devices must divide the total input batch size. Under the hood it uses jax.pmap (see https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html if you need to write something more flexible).
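
For concreteness, a sketch of wrapping the empirical kernel above with nt.batch, assuming the P = 200 inputs from the original code (the batch_size value is illustrative; it must satisfy the divisibility requirement for your device count):

# nt.batch tiles the (P, P) kernel computation across devices via jax.pmap.
kernel_fn = nt.empirical_ntk_fn(apply_fn_a, vmap_axes=0, trace_axes=(),
                                implementation=1)
batched_kernel_fn = nt.batch(
    kernel_fn,
    batch_size=25,          # batch_size * device_count must divide P
    device_count=-1,        # -1: use all available devices
    store_on_device=False)  # assemble the full kernel in host memory

params_a = get_params(opt_state)[:3]
K_a_full = batched_kernel_fn(inputs, None, params_a)  # all (P, P) blocks at once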
