This repository has been archived by the owner on Aug 27, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
embeddings.py
123 lines (107 loc) · 4.03 KB
/
embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter as P
from torchvision.models.inception import inception_v3
from models.rae import RAE_SN
class IdentityEmbedding:
    """No-op embedding: returns its input unchanged.

    Useful as a drop-in placeholder wherever a callable embedding is expected.
    """

    def __call__(self, y):
        # Pure pass-through — no device move, no copy.
        return y
class OneHotEmbedding:
    """Maps integer class labels to one-hot vectors on the GPU.

    ``__call__`` expects a 1D array/tensor of class indices in
    ``[0, num_classes)`` and returns a float tensor of shape
    ``(len(y), num_classes)``.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # Build the identity lookup table once here, instead of
        # re-allocating a num_classes x num_classes matrix on every call.
        self._eye = torch.eye(num_classes)

    def __call__(self, y):
        # Row-select from the identity matrix: row i is the one-hot
        # vector for class i.
        # NOTE(review): hard-codes .cuda(); assumes a GPU is always
        # available — confirm callers never run CPU-only.
        return self._eye[y].cuda()
class InceptionEmbedding:
    """Embeds images with a pretrained Inception-v3 feature extractor.

    Inputs passed to ``__call__`` are expected to be scaled to [-1, 1]
    (the wrapper handles re-normalization internally).
    """

    def __init__(self, parallel=False):
        # Build the pretrained backbone, freeze it in eval mode, wrap it
        # so it accepts [-1, 1] inputs, and move it to the GPU.
        backbone = inception_v3(pretrained=True, transform_input=False)
        wrapped = WrapInception(backbone.eval()).cuda()
        # Optionally shard batches across all visible GPUs.
        self.inception_model = nn.DataParallel(wrapped) if parallel else wrapped

    def __call__(self, x):
        return self.inception_model(x)
# Wrapper for Inceptionv3, from Andrew Brock (modified slightly)
# https://github.com/ajbrock/BigGAN-PyTorch/blob/master/inception_utils.py
class WrapInception(nn.Module):
    """Wraps torchvision Inception-v3 as a pooled-feature extractor.

    Takes images in [-1, 1], re-normalizes them with ImageNet channel
    statistics, resizes to 299x299 if needed, runs the Inception trunk
    (stem through Mixed_7c), and returns the globally average-pooled
    feature vector (2048-d for the standard backbone).
    """

    # Trunk stages applied in order; the 'maxpool' sentinel marks the two
    # stride-2 3x3 max-pools between the stem convolutions.
    _STAGES = (
        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'maxpool',
        'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'maxpool',
        'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
        'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e',
        'Mixed_7a', 'Mixed_7b', 'Mixed_7c',
    )

    def __init__(self, net):
        super(WrapInception, self).__init__()
        self.net = net
        # ImageNet channel statistics, shaped (1, C, 1, 1) to broadcast
        # over NCHW batches. Stored as frozen Parameters so they follow
        # the module across .cuda()/.to() calls.
        self.mean = P(torch.tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1),
                      requires_grad=False)
        self.std = P(torch.tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1),
                     requires_grad=False)

    def forward(self, x):
        # Map [-1, 1] input back to [0, 1], then apply ImageNet normalization.
        h = (x + 1.) / 2.0
        h = (h - self.mean) / self.std
        # Inception v3 requires 299x299 spatial input.
        if h.shape[2] != 299 or h.shape[3] != 299:
            h = F.interpolate(h, size=(299, 299), mode='bilinear',
                              align_corners=True)
        # Run the trunk stage by stage (aux classifier and fc head skipped).
        for stage in self._STAGES:
            if stage == 'maxpool':
                h = F.max_pool2d(h, kernel_size=3, stride=2)
            else:
                h = getattr(self.net, stage)(h)
        # Global average pool over the final 8x8 feature map -> (N, C).
        return F.adaptive_avg_pool2d(h, 1).view(h.size(0), -1)
class AutoencoderEmbedding(object):
    """Embeds inputs with the encoder of a pretrained RAE_SN autoencoder.

    Args:
        n_classes: number of classes the autoencoder was trained with.
        nef: encoder width multiplier (forwarded to RAE_SN).
        ndf: decoder width multiplier (forwarded to RAE_SN).
        latent_dim: dimensionality of the latent code returned by __call__.
        weights_path: path to a checkpoint dict containing 'state_dict'
            and 'epoch' keys.

    Raises:
        FileNotFoundError: if weights_path does not point to a file.
    """

    def __init__(self, n_classes, nef, ndf, latent_dim, weights_path):
        self.embedding = RAE_SN(n_classes,
                                nef=nef,
                                ndf=ndf,
                                latent_dim=latent_dim)
        self.embedding = self.embedding.eval().cuda()
        self.load_weights(weights_path)

    def load_weights(self, weights_path):
        """Load a checkpoint into self.embedding, raising if it is missing."""
        # Raise instead of print + exit(): exit() in library code kills the
        # whole process and cannot be caught or tested by callers.
        if not os.path.isfile(weights_path):
            raise FileNotFoundError(
                "=> no checkpoint found at '{}'".format(weights_path))
        print("=> loading checkpoint '{}'".format(weights_path))
        # map_location='cpu' lets GPU-saved checkpoints load on any host;
        # load_state_dict then copies the tensors onto the CUDA model.
        checkpoint = torch.load(weights_path, map_location='cpu')
        self.embedding.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(weights_path, checkpoint['epoch']))

    def __call__(self, x):
        # Only the encoder half of the autoencoder is used: x -> latent code.
        return self.embedding.encode(x)