[Bug] CUDA out of memory, strange numbers #2518

sanaamouzahir opened this issue Apr 29, 2024 · 1 comment
sanaamouzahir commented Apr 29, 2024

🐛 Bug

I am using the sparse variational GPyTorch framework to perform 7500 tasks. I have 4800 data points and I am training with mini-batches (so both the input and output matrices have dimension (4800, 7500)). Even with a batch size of 1 I get a memory allocation error, which should not happen given that I am using the variational framework. Also, I was not having this issue with the RBF kernel, even when I was not using batches and had the full dataset on the GPU.
```python
import numpy as np
import torch
import tqdm
import gpytorch
from torch.utils.data import DataLoader, random_split
from sklearn.cluster import KMeans

# dataset, train_size, X_scaled, X_tensor, y_tensor and device come from the
# data preparation (not shown)
train_dataset, test_dataset = random_split(dataset, [train_size, len(dataset) - train_size])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)

# 200 inducing point centers chosen by k-means on the scaled inputs
kmeans = KMeans(n_clusters=200, random_state=0).fit(X_scaled)
inducing_points_centers = kmeans.cluster_centers_
# Model definition
class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, num_latents, num_tasks, inducing_points_centers):
        n_features = X_tensor.size(-1)
        inducing_points = torch.tensor(np.repeat(inducing_points_centers[np.newaxis, :, :], num_latents, axis=0), dtype=torch.float)
        
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(inducing_points.size(-2), batch_shape=torch.Size([num_latents]))
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(self, inducing_points, variational_distribution, learn_inducing_locations=True),
            num_tasks=num_tasks, num_latents=num_latents, latent_dim=-1
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.PeriodicKernel(batch_shape=torch.Size([num_latents])))
        #self.covar_module.base_kernel.period_length = 643
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

model = MultitaskGPModel(30, y_tensor.size(-1), inducing_points_centers).to(device)
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=y_tensor.size(-1)).to(device)
alpha = torch.tensor([0.1], device=device, requires_grad=True)  # requires_grad so the optimizer can update alpha
optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
    {'params':[alpha]}], lr=0.01)
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=len(train_dataset))

def make_recursive_prediction(model, x, n_forward):
    """ Generate recursive predictions for each batch. """
    model.eval()
    likelihood.eval()
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        predictions = x
        for _ in range(n_forward):
            predictions = likelihood(model(predictions)).mean.detach()
    return predictions

model.train()
likelihood.train()
torch.cuda.empty_cache()
for epoch in tqdm.tqdm(range(100)):  # Adjust the number of epochs as necessary
    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        output = model(x_batch)
        loss = -mll(output, y_batch)
        loss.backward()
        optimizer.step()
        torch.cuda.empty_cache()
    if epoch % 10 == 0:
        print(f'Epoch {epoch} - Loss: {loss.item()}')
    torch.cuda.empty_cache()
print("Training complete.")


```

**Stack trace/error message**

CUDA out of memory. Tried to allocate 172.00 MiB (GPU 5; 47.54 GiB total capacity; 5.74 GiB already allocated; 105.56 MiB free; 5.76 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF




I do not have this memory allocation issue with other kernels, only with the periodic kernel. I think there is a bug somewhere, since with such a small batch size and the variational framework there should not, a priori, be any memory issue.
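One way to check this (a minimal sketch, reusing the `model`, `likelihood`, `mll`, `train_dataset` and `device` objects defined above; the batch size of 1 is just for the experiment) is to measure the peak GPU memory of a single training step with PyTorch's CUDA memory statistics:

```python
import torch
from torch.utils.data import DataLoader

# Grab a single mini-batch of size 1 from the training set defined above.
probe_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
x_batch, y_batch = next(iter(probe_loader))
x_batch, y_batch = x_batch.to(device), y_batch.to(device)

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(device)

# One forward/backward pass through the variational model and ELBO.
output = model(x_batch)
loss = -mll(output, y_batch)
loss.backward()

peak_mib = torch.cuda.max_memory_allocated(device) / 1024**2
print(f"Peak GPU memory for one step with batch size 1: {peak_mib:.1f} MiB")
```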



- GPyTorch Version: 1.11
- PyTorch Version: 1.13.1
- Computer OS: Linux

m-julian commented May 3, 2024

The periodic kernel makes an `... x d x m x n` tensor (the `...` are batch dimensions, `d` is the dimensionality of the input, and `m` and `n` are the numbers of points), so that might be causing the memory error. For the RBF kernel you only get a `... x m x n` tensor. Maybe try with fewer data points and check what size of covariance matrix you get?

`diff = self.covar_dist(x1_, x2_, diag=diag, last_dim_is_batch=True, **params)`
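To see the shape difference concretely, here is a small sketch (the numbers of latents, points and input features below are made up, not taken from the issue) comparing the intermediate tensor `PeriodicKernel` has to build with what `RBFKernel` needs:

```python
import torch
import gpytorch

# Hypothetical sizes: 30 latent GPs, 264 points, 10 input features.
num_latents, n, d = 30, 264, 10
x = torch.randn(n, d)

periodic = gpytorch.kernels.PeriodicKernel(batch_shape=torch.Size([num_latents]))
rbf = gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents]))

# The final covariance matrices have the same shape for both kernels ...
print(periodic(x).to_dense().shape)  # torch.Size([30, 264, 264])
print(rbf(x).to_dense().shape)       # torch.Size([30, 264, 264])

# ... but PeriodicKernel materialises a num_latents x d x n x n distance tensor
# along the way (one slice per input dimension), while RBFKernel only needs
# num_latents x n x n.
periodic_mib = num_latents * d * n * n * 4 / 1024**2  # float32 bytes -> MiB
rbf_mib = num_latents * n * n * 4 / 1024**2
print(f"periodic intermediate ~ {periodic_mib:.1f} MiB, rbf ~ {rbf_mib:.1f} MiB")
```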
