Skip to content

Commit

Permalink
solve device problem
Browse files Browse the repository at this point in the history
  • Loading branch information
BLUE-coconut committed Aug 16, 2023
1 parent 77661b1 commit 0b363ab
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions mmseg/models/losses/pixel_contrast_cross_entropy_loss.py
Expand Up @@ -37,10 +37,14 @@ def hard_anchor_sampling(X, y_hat, y, ignore_index, max_views, max_samples):

n_view = max_samples // total_classes
n_view = min(n_view, max_views)

X_ = torch.zeros((total_classes, n_view, feat_dim),
dtype=torch.float).cuda()
y_ = torch.zeros(total_classes, dtype=torch.float).cuda()
if(torch.cuda.is_available()):
X_ = torch.zeros((total_classes, n_view, feat_dim),
dtype=torch.float).cuda()
y_ = torch.zeros(total_classes, dtype=torch.float).cuda()
else:
X_ = torch.zeros((total_classes, n_view, feat_dim),
dtype=torch.float)
y_ = torch.zeros(total_classes, dtype=torch.float)

X_ptr = 0
for ii in range(batch_size):
Expand Down Expand Up @@ -95,7 +99,10 @@ def contrastive(embed, label, temperature, base_temperature):
anchor_num, n_view = embed.shape[0], embed.shape[1]

label = label.reshape((-1, 1))
mask = torch.eq(label, label.permute([1, 0])).float().cuda()
if(torch.cuda.is_available()):
mask = torch.eq(label, label.permute([1, 0])).float().cuda()
else:
mask = torch.eq(label, label.permute([1, 0])).float()

contrast_count = n_view
contrast_feature = torch.concat(torch.unbind(embed, dim=1), dim=0)
Expand All @@ -112,9 +119,12 @@ def contrastive(embed, label, temperature, base_temperature):
mask = torch.tile(mask, [anchor_count, contrast_count])
neg_mask = 1 - mask

logits_mask = torch.ones_like(mask).scatter_(
1,
torch.arange(anchor_num * anchor_count).view(-1, 1).cuda(), 0)
if(torch.cuda.is_available()):
logits_mask = torch.ones_like(mask).scatter_(1,
torch.arange(anchor_num * anchor_count).view(-1, 1).cuda(), 0)
else:
logits_mask = torch.ones_like(mask).scatter_(1,
torch.arange(anchor_num * anchor_count).view(-1, 1), 0)

mask = mask * logits_mask

Expand Down

0 comments on commit 0b363ab

Please sign in to comment.