BorealisAI/ConR

ConR: Contrastive Regularizer for Deep Imbalanced Regression

This repository contains the code for the ICLR 2024 paper: ConR: Contrastive Regularizer for Deep Imbalanced Regression
Mahsa Keramati, Lili Meng, R David Evans


Key insights of ConR. (a) Without ConR, minority examples are commonly mixed with majority examples. (b) ConR adds loss weight to minority and mislabelled examples, resulting in better feature representations and (c) lower prediction error.

Quick Preview

ConR is complementary to conventional imbalanced learning techniques. The following code snippet shows the implementation of ConR for the task of age estimation:

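import torch


def ConR(features, targets, preds, w=1, weights=1, t=0.07, e=0.01):
    # features: (N, C) embeddings; targets, preds: (N, 1) labels and predictions.
    # w: label-similarity window; t: temperature; e: pushing-weight coefficient.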
    q = torch.nn.functional.normalize(features, dim=1)
    k = torch.nn.functional.normalize(features, dim=1)

    l_k = targets.flatten()[None, :]
    l_q = targets

    p_k = preds.flatten()[None, :]
    p_q = preds

    l_dist = torch.abs(l_q - l_k)
    p_dist = torch.abs(p_q - p_k)

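    # Positive pairs: labels within w of each other.
    pos_i = l_dist.le(w)
    # Negative pairs: labels further apart than w but predictions within w.
    neg_i = (~l_dist.le(w)) & p_dist.le(w)

    # A sample is not its own positive.
    pos_i.fill_diagonal_(False)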

    prod = torch.einsum("nc,kc->nk", [q, k]) / t
    pos = prod * pos_i
    neg = prod * neg_i

    pushing_w = weights * torch.exp(l_dist * e)
    neg_exp_dot = (pushing_w * (torch.exp(neg)) * neg_i).sum(1)

    # For each query sample, if there is no negative pair, zero-out the loss.
    no_neg_flag = (neg_i).sum(1).bool()

    # Per-query loss: -log(exp(pos) / (sum of exp(pos) + weighted sum of exp(neg))),
    # averaged over that query's positives.
    # Note: a query with no positive pair has denom == 0, which yields NaN.
    denom = pos_i.sum(1)

    loss = ((-torch.log(
        torch.div(torch.exp(pos), (torch.exp(pos).sum(1) + neg_exp_dot).unsqueeze(-1))) * (
                 pos_i)).sum(1) / denom)

    loss = (weights * (loss * no_neg_flag).unsqueeze(-1)).mean()

    return loss
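
To make the pair selection concrete, the sketch below builds only the positive and negative masks for a toy batch of four samples. The values and the names `pos_mask` and `neg_mask` are illustrative assumptions and are not taken from the repository; the sketch mirrors the masking logic of the function above and does not compute the loss.

```python
import torch

# Toy batch (illustrative values). targets and preds are column vectors of
# shape (N, 1), the layout the broadcasting in ConR relies on.
targets = torch.tensor([[20.0], [21.0], [40.0], [41.0]])
preds = torch.tensor([[20.5], [39.5], [40.5], [21.5]])
w = 1.0  # similarity window, same role as `w` in ConR

# Pairwise absolute distances between labels and between predictions, shape (N, N).
l_dist = torch.abs(targets - targets.flatten()[None, :])
p_dist = torch.abs(preds - preds.flatten()[None, :])

# Positives: labels within the window, excluding each sample paired with itself.
pos_mask = l_dist.le(w)
pos_mask.fill_diagonal_(False)

# Negatives: labels differ by more than w, yet predictions fall within w.
neg_mask = (~l_dist.le(w)) & p_dist.le(w)

print(pos_mask)
print(neg_mask)
```

In this toy batch, samples 1 and 3 are predicted close to samples 2 and 0 respectively even though their labels are far apart, so those pairs are selected as negatives, while the pairs with neighbouring labels (0, 1) and (2, 3) are selected as positives.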

Usage

To run experiments, go into the sub-folder for the corresponding dataset.

Acknowledgment

The code is based on Yang et al., Delving into Deep Imbalanced Regression, ICML 2021, and Ren et al., Balanced MSE for Imbalanced Visual Regression, CVPR 2022.
