
Online hard-example mining/examining under multi-GPU ('dp') #1170


Have you considered using a library such as pytorch-metric-learning?

In general, it would look something like this:

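```python
import torch
import pytorch_lightning as pl
from pytorch_metric_learning import losses, miners


class MinerNetwork(pl.LightningModule):
    def __init__(self, network: torch.nn.Module):
        super().__init__()
        self.network = network  # any module that maps inputs to embeddings
        # samples triplets from each batch, weighted by embedding distance
        self.miner_function = miners.DistanceWeightedMiner(cutoff=0.5, nonzero_loss_cutoff=1.4)
        self.objective = losses.TripletMarginLoss()

    def forward(self, data):
        # labels are not needed to compute embeddings
        return self.network(data)

    def training_step(self, batch, batch_idx):
        data, labels = batch
        embeddings = self(data)
        # mine (anchor, positive, negative) indices within this batch only
        triplets = self.miner_function(embeddings, labels)
        loss = self.objective(embeddings, labels, triplets)
        return loss

    def configure_optimizers(self):
        # illustrative optimizer and learning rate
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```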

This does mining within each batch that you pass in. I'm not sure…
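To see why "within each batch" matters for the 'dp' question in the title, here is a minimal, framework-free sketch. It swaps in batch-hard mining (farthest positive and closest negative per anchor) instead of the `DistanceWeightedMiner` used above, and the helpers `batch_hard_triplets` and `mine_per_chunk` are hypothetical names for illustration. `mine_per_chunk` assumes that under 'dp' each GPU runs `training_step` on its own contiguous slice of the batch and therefore mines only among the samples in that slice.

```python
import math
from typing import List, Sequence, Tuple

Triplet = Tuple[int, int, int]


def batch_hard_triplets(embeddings: Sequence[Sequence[float]],
                        labels: Sequence[int]) -> List[Triplet]:
    """Batch-hard mining: for each anchor, pick the farthest positive and the
    closest negative among the samples of this batch only."""
    triplets: List[Triplet] = []
    for a in range(len(labels)):
        # (distance, index) pairs so max/min compare by distance first
        pos = [(math.dist(embeddings[a], embeddings[i]), i)
               for i in range(len(labels)) if i != a and labels[i] == labels[a]]
        neg = [(math.dist(embeddings[a], embeddings[i]), i)
               for i in range(len(labels)) if labels[i] != labels[a]]
        if pos and neg:  # an anchor needs at least one positive and one negative
            triplets.append((a, max(pos)[1], min(neg)[1]))
    return triplets


def mine_per_chunk(embeddings: Sequence[Sequence[float]], labels: Sequence[int],
                   num_devices: int) -> List[Triplet]:
    """Mine separately on contiguous chunks of the batch (a rough stand-in for
    how 'dp' splits a batch across GPUs); indices refer to the full batch."""
    size = math.ceil(len(labels) / num_devices)
    triplets: List[Triplet] = []
    for start in range(0, len(labels), size):
        local = batch_hard_triplets(embeddings[start:start + size],
                                    labels[start:start + size])
        triplets += [(a + start, p + start, n + start) for a, p, n in local]
    return triplets


if __name__ == "__main__":
    emb = [[0.0], [1.0], [0.5], [3.0]]  # toy 1-D embeddings
    lab = [0, 0, 1, 1]                  # label-sorted batch
    print(batch_hard_triplets(emb, lab))            # [(0, 1, 2), (1, 0, 2), (2, 3, 0), (3, 2, 1)]
    print(mine_per_chunk(emb, lab, num_devices=2))  # []
```

With this label-sorted toy batch, mining on the full batch finds a triplet for every anchor, while mining on two halves finds none, because each half contains a single class. With shuffled batches the effect is milder, but each GPU still picks its hard examples from a smaller candidate pool.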

Replies: 3 comments

Answer selected by Borda
3 participants

This discussion was converted from issue #1170 on December 23, 2020 19:32.