Is there an interface to get entity / relation IDs in a custom Interaction? #1385
-
Hi @Jeffrey-Sardina, ordered by increasing effort, I can see the following options:
```python
import torch

from pykeen.datasets import get_dataset
from pykeen.models.nbase import ERModel
from pykeen.nn.representation import Embedding, Representation
from pykeen.pipeline import pipeline


def get_degree_map(dataset):
    # your method
    triples = dataset.factory_dict["training"].mapped_triples
    degrees = {}
    for s, _, o in triples:
        s, o = int(s), int(o)
        if s not in degrees:
            degrees[s] = 0
        if o not in degrees:
            degrees[o] = 0
        degrees[s] += 1
        degrees[o] += 1
    return degrees


dataset = get_dataset(dataset="nations")
entity_to_degree = get_degree_map(dataset)


class ScaledRepresentationMixin(Representation):
    # only for type annotation
    scaling_factor: torch.Tensor

    def __init__(self, *args, scaling_factor: torch.Tensor, **kwargs):
        super().__init__(*args, **kwargs)
        self.register_buffer(name="scaling_factor", tensor=scaling_factor)

    def _plain_forward(self, indices: torch.LongTensor | None = None) -> torch.FloatTensor:
        x = super()._plain_forward(indices=indices)
        scale = self.scaling_factor if indices is None else self.scaling_factor[indices]
        return scale * x


class ScaledEmbedding(ScaledRepresentationMixin, Embedding):
    # note: the mixin needs to be first
    pass


scaling_factor = (
    torch.as_tensor([entity_to_degree.get(index, 0) for index in range(dataset.num_entities)]).float().view(-1, 1)
)

model = ERModel(
    triples_factory=dataset.training,
    interaction="TransE",
    interaction_kwargs=dict(p=2),
    entity_representations=ScaledEmbedding,
    entity_representations_kwargs=dict(scaling_factor=scaling_factor, embedding_dim=32),
    relation_representations_kwargs=dict(embedding_dim=32),
)

result = pipeline(dataset=dataset, model=model)
```
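If you want to sanity-check that the scaling is actually applied, you can query the wrapped representation directly; a quick sketch against the model built above (accessing the first entity representation via `model.entity_representations[0]`):

```python
rep = model.entity_representations[0]  # the ScaledEmbedding instance
x = rep(indices=torch.as_tensor([0, 1]))
print(x.shape)                 # (2, 32): embeddings, already multiplied per entity
print(rep.scaling_factor[:2])  # the degree-based factors used for rows 0 and 1
```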
P.S.: You can find tools to count entities/relations/... also in
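For the degree counting itself, a vectorized equivalent of `get_degree_map` above, using plain PyTorch (a sketch, assuming the same `dataset` object):

```python
triples = dataset.factory_dict["training"].mapped_triples
# count how often each entity ID appears as head or tail
degree = torch.bincount(
    torch.cat([triples[:, 0], triples[:, 2]]),
    minlength=dataset.num_entities,
)
scaling_factor = degree.float().view(-1, 1)
```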
-
Wonderful, thank you, that helps a lot! The specific use case I am looking for is a bit more complex (the example above was just to illustrate the idea), so I will likely subclass ERModel for this. I'm happy to share the code once I get it working, if that would be of use as an example or the like.
-
Ok, I am almost there! Basically, I think I have all the right ideas here, with a minor blocker. I have extended the ERModel class as so:
I then run it with
This works great, except that my custom forward function in
The assertion never triggers. However, the forward function in
-
Ok, I got it! @mberr My solution is very hacky, but it seems to work. Basically, you define your model and interaction as follows:

```python
import torch
from torch import nn
import pykeen
from pykeen.nn.modules import Interaction, parallel_unsqueeze
import torch.nn.functional as F
from pykeen.nn.representation import Representation
from typing import Optional, Sequence
from pykeen.nn import Embedding
class ID_Based_Interaction(Interaction):
    '''
    Note: this is very similar to ERMLP
    (see: https://pykeen.readthedocs.io/en/stable/byo/interaction.html)
    '''

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__()
        # your code here...

    def update(
        self,
        _get_representations=None,
    ):
        assert _get_representations, f'_get_representations must have a value but is {_get_representations}'
        self._get_representations = _get_representations

    def forward(
        self,
        mode=None,
        h_ids=None,
        r_ids=None,
        t_ids=None,
        **kwargs,
    ):
        # get embeddings from the IDs
        h, r, t = self._get_representations(
            h=h_ids,
            r=r_ids,
            t=t_ids,
            mode=mode,
        )
        # compute a score using the embeddings
        return -(h + r - t).norm(p=2, dim=-1)

    def score_hrt(
        self,
        mode=None,
        h_ids=None,
        r_ids=None,
        t_ids=None,
        **kwargs,
    ):
        scores = self.forward(
            mode=mode,
            h_ids=h_ids,
            r_ids=r_ids,
            t_ids=t_ids,
            **kwargs,
        )
        scores = torch.unsqueeze(scores, 1)
        return scores


def repeat_if_necessary(
    scores: torch.FloatTensor,
    representations: Sequence[Representation],
    num: Optional[int],
) -> torch.FloatTensor:
    '''
    copied from https://pykeen.readthedocs.io/en/stable/_modules/pykeen/models/nbase.html#ERModel
    '''
    if representations:
        return scores
    return scores.repeat(1, num)
class ID_Based_Model(pykeen.models.multimodal.base.ERModel):
    def __init__(
        self,
        dim: int = -1,
        **kwargs,
    ) -> None:
        super().__init__(
            interaction=ID_Based_Interaction,
            interaction_kwargs={
                # ... your args here ...
            },
            entity_representations=Embedding,
            entity_representations_kwargs=dict(
                embedding_dim=dim,
            ),
            relation_representations=Embedding,
            relation_representations_kwargs=dict(
                embedding_dim=dim,
            ),
            **kwargs,
        )
        # we can't init the interaction with these, because that would be a
        # circular dependency, so instead we immediately pass them to an
        # update function
        self.interaction.update(
            _get_representations=self._get_representations,
        )
    def forward(
        self,
        h_indices: torch.LongTensor,
        r_indices: torch.LongTensor,
        t_indices: torch.LongTensor,
        slice_size: Optional[int] = None,
        slice_dim: int = 0,
        *,
        mode=None,
    ) -> torch.FloatTensor:
        if not self.entity_representations or not self.relation_representations:
            raise NotImplementedError("repeat scores not implemented for general case.")
        return self.interaction.score(
            mode=mode,
            h_ids=h_indices,
            r_ids=r_indices,
            t_ids=t_indices,
            slice_size=slice_size,
            slice_dim=slice_dim,
        )

    def score_hrt(
        self,
        hrt_batch: torch.LongTensor,
        *,
        mode=None,
    ) -> torch.FloatTensor:
        return self.interaction.score_hrt(
            mode=mode,
            h_ids=hrt_batch[:, 0],
            r_ids=hrt_batch[:, 1],
            t_ids=hrt_batch[:, 2],
        )
    def score_h(
        self,
        rt_batch: torch.LongTensor,
        *,
        slice_size: Optional[int] = None,
        mode=None,
        heads: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        # normalize before checking
        if slice_size and slice_size >= self.num_entities:
            slice_size = None
        self._check_slicing(slice_size=slice_size)
        # slice early to allow lazy computation of target representations
        if slice_size:
            return torch.cat(
                [
                    self.score_h(
                        rt_batch=rt_batch,
                        slice_size=None,
                        mode=mode,
                        heads=torch.arange(start=start, end=min(start + slice_size, self.num_entities)),
                    )
                    for start in range(0, self.num_entities, slice_size)
                ],
                dim=-1,
            )
        # add broadcast dimension
        rt_batch = rt_batch.unsqueeze(dim=1)
        # keep these as IDs, not embedding vectors
        # (check against None explicitly: `if not heads` raises for multi-element tensors)
        if heads is None:
            heads = torch.arange(start=0, end=self.num_entities)
        # unsqueeze if necessary
        if heads.ndimension() == 1:
            heads = parallel_unsqueeze(heads, dim=1)
        h = heads
        r = rt_batch[..., 0]
        t = rt_batch[..., 1]
        assert r.shape == t.shape
        num_rt_pairs = r.shape[0]
        num_heads = h.shape[0]
        # repeat-interleave on h so all h's get mapped to all rt pairs
        # https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html
        h = h.repeat_interleave(num_rt_pairs, dim=0)
        r = r.repeat(num_heads, 1)
        t = t.repeat(num_heads, 1)
        scores = repeat_if_necessary(
            scores=self.interaction(
                mode=mode,
                h_ids=h,
                r_ids=r,
                t_ids=t,
            ),
            representations=self.entity_representations,
            num=heads.shape[-1],
        )
        scores = scores.reshape(-1, num_heads)
        return scores
    def score_r(
        self,
        ht_batch: torch.LongTensor,
        *,
        slice_size: Optional[int] = None,
        mode=None,
        relations: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        # normalize before checking
        if slice_size and slice_size >= self.num_relations:
            slice_size = None
        self._check_slicing(slice_size=slice_size)
        # slice early to allow lazy computation of target representations
        if slice_size:
            return torch.cat(
                [
                    self.score_r(
                        ht_batch=ht_batch,
                        slice_size=None,
                        mode=mode,
                        relations=torch.arange(start=start, end=min(start + slice_size, self.num_relations)),
                    )
                    for start in range(0, self.num_relations, slice_size)
                ],
                dim=-1,
            )
        # add broadcast dimension
        ht_batch = ht_batch.unsqueeze(dim=1)
        # keep these as IDs, not embedding vectors
        # (check against None explicitly: `if not relations` raises for multi-element tensors)
        if relations is None:
            relations = torch.arange(start=0, end=self.num_relations)
        # unsqueeze if necessary
        if relations.ndimension() == 1:
            relations = parallel_unsqueeze(relations, dim=1)
        h = ht_batch[..., 0]
        r = relations
        t = ht_batch[..., 1]
        assert h.shape == t.shape
        num_ht_pairs = h.shape[0]
        num_relations = r.shape[0]
        # repeat-interleave on r so all r's get mapped to all ht pairs
        # https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html
        h = h.repeat(num_relations, 1)
        r = r.repeat_interleave(num_ht_pairs, dim=0)
        t = t.repeat(num_relations, 1)
        scores = repeat_if_necessary(
            scores=self.interaction(
                mode=mode,
                h_ids=h,
                r_ids=r,
                t_ids=t,
            ),
            representations=self.relation_representations,
            num=relations.shape[-1],
        )
        scores = scores.reshape(-1, num_relations)
        return scores
    def score_t(
        self,
        hr_batch: torch.LongTensor,
        *,
        slice_size: Optional[int] = None,
        mode=None,
        tails: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:  # noqa: D102
        # normalize before checking
        if slice_size and slice_size >= self.num_entities:
            slice_size = None
        self._check_slicing(slice_size=slice_size)
        # slice early to allow lazy computation of target representations
        if slice_size:
            return torch.cat(
                [
                    self.score_t(
                        hr_batch=hr_batch,
                        slice_size=None,
                        mode=mode,
                        tails=torch.arange(start=start, end=min(start + slice_size, self.num_entities)),
                    )
                    for start in range(0, self.num_entities, slice_size)
                ],
                dim=-1,
            )
        # add broadcast dimension
        hr_batch = hr_batch.unsqueeze(dim=1)
        # keep these as IDs, not embedding vectors
        # (check against None explicitly: `if not tails` raises for multi-element tensors)
        if tails is None:
            tails = torch.arange(start=0, end=self.num_entities)
        # unsqueeze if necessary
        if tails.ndimension() == 1:
            tails = parallel_unsqueeze(tails, dim=1)
        h = hr_batch[..., 0]
        r = hr_batch[..., 1]
        t = tails
        assert h.shape == r.shape
        num_hr_pairs = r.shape[0]
        num_tails = t.shape[0]
        # repeat-interleave on t so all t's get mapped to all hr pairs
        # https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html
        h = h.repeat(num_tails, 1)
        r = r.repeat(num_tails, 1)
        t = t.repeat_interleave(num_hr_pairs, dim=0)
        scores = repeat_if_necessary(
            scores=self.interaction(
                mode=mode,
                h_ids=h,
                r_ids=r,
                t_ids=t,
            ),
            representations=self.entity_representations,
            num=tails.shape[-1],
        )
        scores = scores.reshape(-1, num_tails)
        return scores
```

You can run it as follows:

```python
from pykeen.pipeline import pipeline
from id_based_model import ID_Based_Model

def main():
    pipeline_result = pipeline(
        dataset='UMLS',
        model=ID_Based_Model,
        model_kwargs={
            'dim': 100,
        },
        negative_sampler_kwargs={
            'num_negs_per_pos': 5,
        },
        training_kwargs=dict(
            num_epochs=150,
            batch_size=256,
        ),
    )
    mr = pipeline_result.get_metric('mr')
    mrr = pipeline_result.get_metric('mrr')
    h1 = pipeline_result.get_metric('Hits@1')
    h3 = pipeline_result.get_metric('Hits@3')
    h5 = pipeline_result.get_metric('Hits@5')
    h10 = pipeline_result.get_metric('Hits@10')
    print(f'\nMR = {mr} \nMRR = {mrr} \nHits@(1,3,5,10) = {h1, h3, h5, h10}\n')


if __name__ == '__main__':
    main()
```

Please note that the

As to why you might want to have ID-based models and interactions -- well, that depends very much on your use case. For me, it's about being able to query arbitrary properties of a KG during the embedding process, some of which involves knowing the ID of specific triple elements (such as the degree of a node). I hope this helps the next person to come along looking for something like this!
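To make that concrete, here is a minimal sketch of how the IDs could be used inside the interaction above to consult a precomputed degree table. The class name, the `entity_degree` buffer, and the 1/(1+degree) weighting are illustrative assumptions, not part of the solution code:

```python
import torch


class Degree_Weighted_Interaction(ID_Based_Interaction):
    '''Hypothetical: down-weight scores involving high-degree head entities.'''

    def __init__(self, *args, entity_degree: torch.Tensor, **kwargs):
        super().__init__(*args, **kwargs)
        # per-entity degree, indexable by entity ID; registered as a
        # buffer so it moves to the right device with the model
        self.register_buffer(name='entity_degree', tensor=entity_degree.float())

    def forward(self, mode=None, h_ids=None, r_ids=None, t_ids=None, **kwargs):
        base = super().forward(mode=mode, h_ids=h_ids, r_ids=r_ids, t_ids=t_ids)
        # the raw IDs are available here, so any per-entity property can be queried
        weight = 1.0 / (1.0 + self.entity_degree[h_ids])
        return weight * base
```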
-
TLDR: I know we can get an embedding from an ID. But can we get an entity / relation ID from an embedding?
Currently, in the docs, we can define a custom interaction:
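A minimal custom interaction, roughly following the documented pattern (a TransE-style scorer for illustration; the class name is mine):

```python
import torch

from pykeen.nn.modules import Interaction


class MyInteraction(Interaction):
    # PyKEEN calls forward with the already-looked-up embeddings
    def forward(self, h: torch.FloatTensor, r: torch.FloatTensor, t: torch.FloatTensor) -> torch.FloatTensor:
        return -(h + r - t).norm(p=2, dim=-1)
```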
Suppose, as a toy example, I want to weight each embedding by the degree of the node. (I'm not saying this is necessarily a good idea; it's just to illustrate the goal here.) To do this, it's easy enough to calculate an `entity_to_degree` dictionary. For example, say we want this for the `Nations` dataset:
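A sketch of that computation, mirroring the `get_degree_map` helper from the first reply above:

```python
from pykeen.datasets import get_dataset

dataset = get_dataset(dataset="nations")
entity_to_degree = {}
for s, _, o in dataset.factory_dict["training"].mapped_triples:
    # count each appearance as head or tail
    for e in (int(s), int(o)):
        entity_to_degree[e] = entity_to_degree.get(e, 0) + 1
```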
We now try to incorporate this into our custom interaction; however, that will cause an error, because PyKEEN passes embeddings, not entity / relation IDs, to the interaction. In order to actually make this work, we need some way to map an embedding to an ID; for example:
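Illustratively, the desired helper would be used like this; `get_id_of_emb` and the class name are hypothetical:

```python
class Degree_Interaction(Interaction):
    def forward(self, h, r, t):
        # hypothetical: recover the entity ID from its embedding vector
        h_id = get_id_of_emb(h)  # no such function exists in PyKEEN
        weight = entity_to_degree[h_id]
        return -weight * (h + r - t).norm(p=2, dim=-1)
```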
However, I cannot find any function that works like `get_id_of_emb` in the docs. Neither can I find a way to have PyKEEN pass IDs, rather than embeddings, to the interaction (which would be another solution here, and probably a better one). Is there any way to achieve this behaviour? If not, this would be a much appreciated feature!