Hello,
Answered by mberr, Dec 8, 2023
Yes, it is 🙂 The following is a small snippet to show that:

```python
import torch

from pykeen.datasets import get_dataset
from pykeen.models import TransE
from pykeen.nn.init import PretrainedInitializer
from pykeen.pipeline import pipeline

dataset = get_dataset(dataset="nations")

# assume that these are your pre-trained entity embeddings
pre_trained = torch.rand(dataset.num_entities, 32)

# create the model instance
model = TransE(
    triples_factory=dataset.training,
    embedding_dim=pre_trained.shape[-1],
    # by default, TransE would normalize entity embeddings to unit norm
    entity_constrainer=None,
    entity_initializer=PretrainedInitializer(pre_trained),
)

# set entity representations to non-trainable
model.entity_representations.requires_grad_(False)

# now just train as usual
result = pipeline(dataset=dataset, model=model)

# verify that the entity embeddings stayed the same
assert torch.allclose(pre_trained, model.entity_representations[0]())
```
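The same freezing idea can be illustrated without PyKEEN, in plain PyTorch: initialize an `nn.Embedding` from pre-trained weights with `freeze=True`, train only a separate head, and check that the embedding weights are untouched. This is a minimal sketch (the 14-entity size and the linear head are made up for illustration), not PyKEEN's internal mechanism:

```python
import torch
from torch import nn

# pretend these are pre-trained entity embeddings (14 entities, dim 32)
pre_trained = torch.rand(14, 32)

# freeze=True sets requires_grad=False on the embedding weight
emb = nn.Embedding.from_pretrained(pre_trained.clone(), freeze=True)
head = nn.Linear(32, 1)  # some trainable scoring head

# optimize only the parameters that still require gradients
params = [p for p in list(emb.parameters()) + list(head.parameters()) if p.requires_grad]
opt = torch.optim.SGD(params, lr=0.1)

for _ in range(3):
    loss = head(emb(torch.arange(14))).sum()
    opt.zero_grad()
    loss.backward()
    opt.step()

# the frozen embedding weights did not move during training
assert torch.allclose(emb.weight, pre_trained)
```

The key point is the same as in the PyKEEN snippet: turning off `requires_grad` on the representation module excludes it from optimization while the rest of the model trains normally.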
Answer selected by pkociepka