
QuatE: GPU memory is not released per epoch #1351

Open
3 tasks done
LuisaWerner opened this issue Dec 8, 2023 · 3 comments
Labels: bug (Something isn't working)

Comments

@LuisaWerner
Contributor

Describe the bug

Hi,
I am training the KGE model QuatE on a CUDA device and I am running into a CUDA out-of-memory error after a few epochs.

I have looked at the allocated memory at various points of the training loop. The allocated CUDA memory increases with each training batch and also with each epoch, so that a CUDA OOM error occurs after a certain number of epochs.
Here is a graphic that visualises the problem.

[Screenshot 2023-12-08 at 11 43 12: plot of allocated CUDA memory growing per batch and per epoch]
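For reference, a minimal sketch of this kind of memory logging (torch.cuda.memory_allocated and torch.cuda.max_memory_allocated are standard PyTorch calls; the helper name and its placement are illustrative):

import torch

def log_cuda_memory(tag: str) -> None:
    # print currently allocated and peak CUDA memory in MiB
    allocated = torch.cuda.memory_allocated() / 2**20
    peak = torch.cuda.max_memory_allocated() / 2**20
    print(f"{tag}: allocated={allocated:.1f} MiB, peak={peak:.1f} MiB")

# called after every batch and at the end of every epoch in the training loop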

I have also tested other KGE methods (BoxE, TransE, CrossE, ConvKB, RGCN, NTN) with the same code and have not observed this problem with any of them; with those models, the allocated memory remains constant per batch and per epoch.

Do you have any hints where the problem comes from and how to fix it? I took a closer look at the QuatEInteraction and realised that a buffer called table is created. Could the problem perhaps lie here?

class QuatEInteraction(
    FunctionalInteraction[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
    ],
):
    """A module wrapper for the QuatE interaction function.

    .. seealso:: :func:`pykeen.nn.functional.quat_e_interaction`
    """

    # with k=4
    entity_shape: Sequence[str] = ("dk",)
    relation_shape: Sequence[str] = ("dk",)
    func = pkf.quat_e_interaction

    def __init__(self) -> None:
        """Initialize the interaction module."""
        super().__init__()
        self.register_buffer(name="table", tensor=quaterion_multiplication_table())

How to reproduce

The code to instantiate the model looks like this:

model = QuatE(triples_factory=triples_factory, loss=loss, random_seed=conf.seed)

The code that I use for one epoch looks like this:

    def train_step(self, triples: TriplesFactory) -> tuple[float, float]:
        """ One epoch on the training set """
        start = time()
        train_data_loader = TripleDataLoader(triples=TripleDataset(triples), batch_size=self.batch_size)
        self.model.train()
        train_loss = 0

        for i_batch, batch in enumerate(train_data_loader):
            self.optimizer.zero_grad()
            batch = batch.to(self.device)
            self.model.to(self.device)
            pred_scores = self.model.score_hrt(batch.triples)
            batch_loss = self.model.loss(scores=pred_scores, labels=batch.labels)
            batch_loss.backward()
            self.optimizer.step()
            train_loss += batch_loss.detach().cpu().item()

The batch_size is 256.

Environment

GPU Quadro P5000
Python 3.11.0
torch 2.1.1
pykeen 1.10.1

Additional information

No response

Issue Template Checks

  • This is not a feature request (use a different issue template if it is)
  • This is not a question (use the discussions forum instead)
  • I've read the text explaining why including environment information is important and understand if I omit this information that my issue will be dismissed
@LuisaWerner added the bug label on Dec 8, 2023
@mberr
Member

mberr commented Dec 8, 2023

> Do you have any hints where the problem comes from and how to fix it? I took a closer look at the QuatEInteraction and realised that a buffer called table is created. Could the problem perhaps lie here?

I think this buffer should be unrelated (it is created only once, and pretty small, too).
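One quick way to verify this is to list the registered buffers and their sizes (named_buffers is a standard torch.nn.Module method; model here stands for the instantiated QuatE model):

# list every registered buffer with its shape and size in bytes; the quaternion
# multiplication table holds only a handful of elements
for name, buf in model.named_buffers():
    print(name, tuple(buf.shape), buf.element_size() * buf.numel(), "bytes")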

What does self.model look like?

The default setting of QuatE (the full model configuration, not the interaction) uses a regularizer, and it seems as if you are using a custom training loop, so my best guess would be that the regularization term keeps accumulating without being back-propagated; in that case, torch would not be able to release tensors from previous batches.
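As a generic PyTorch illustration of that mechanism (not pykeen code; the shapes and the sigmoid stand-in are arbitrary): summing graph-attached tensors across iterations keeps every iteration's saved activations reachable, so allocated memory grows until the sum is back-propagated or detached.

import torch

w = torch.randn(1024, 1024, device="cuda", requires_grad=True)
accumulated = torch.zeros((), device="cuda")
for step in range(5):
    term = w.sigmoid().sum()          # stand-in for a per-batch regularization term
    accumulated = accumulated + term  # never back-propagated, never detached
    # sigmoid's saved output (~4 MiB) from every previous step is still reachable
    # through the accumulated graph, so this number grows each iteration
    print(step, torch.cuda.memory_allocated() // 2**20, "MiB")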

@LuisaWerner
Contributor Author

Thanks for your answer and sorry for my late reply!

I use the default QuatE model from the pykeen library here.

It looks like this:

QuatE(
  (loss): BCEWithLogitsLoss()
  (interaction): QuatEInteraction()
  (entity_representations): ModuleList(
    (0): Embedding(
      (regularizer): LpRegularizer()
      (_embeddings): Embedding(3007, 400)
    )
  )
  (relation_representations): ModuleList(
    (0): Embedding(
      (regularizer): LpRegularizer()
      (_embeddings): Embedding(12, 400)
    )
  )
  (weight_regularizers): ModuleList()
)

Are you talking about the weight_regularizers or the LpRegularizer()?

@mberr
Member

mberr commented Jan 8, 2024

I was talking about the two LpRegularizer instances shown as entity_representations[0].regularizer and relation_representations[0].regularizer; weight_regularizers is just an empty list 🙂

You can either

  • disable the two regularizers by providing
model = QuatE(
    ...,
    entity_regularizer=None,
    relation_regularizer=None,
)
  • or collect the regularization term by calling model.collect_regularization_term; you can either ignore this term (but in that case I would suggest using the previous option instead), or add it to the loss before calculating the gradients, as shown in the sketch below.
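A minimal sketch of the second option, applied to the training loop from the issue (the loop body is the one posted above; only the collect_regularization_term line is added):

for i_batch, batch in enumerate(train_data_loader):
    self.optimizer.zero_grad()
    batch = batch.to(self.device)
    self.model.to(self.device)
    pred_scores = self.model.score_hrt(batch.triples)
    batch_loss = self.model.loss(scores=pred_scores, labels=batch.labels)
    # drain the regularization term accumulated during the forward pass and add it
    # to the loss, so its computation graph is released by backward()
    batch_loss = batch_loss + self.model.collect_regularization_term()
    batch_loss.backward()
    self.optimizer.step()
    train_loss += batch_loss.detach().cpu().item()

With the first option the LpRegularizer instances are never created, so nothing can accumulate; with this one, backward() frees the regularization graph every batch.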

As background info:

  • weight_regularizers are regularizers that calculate a regularization term on all weights (of some tensor, e.g., the relation embedding matrix);
  • in contrast, the LpRegularizers you see above only calculate terms from the "activated" embeddings, i.e., the rows of the embedding matrix which are used in the current batch (and thus may also receive a non-zero gradient).
