-
Notifications
You must be signed in to change notification settings - Fork 18
/
myNetwork.py
27 lines (21 loc) · 820 Bytes
/
myNetwork.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch.nn as nn
import torch
class network(nn.Module):
    """Feature extractor followed by a linear classification head.

    The head can be resized with ``Incremental_learning`` so the number of
    output classes can grow while the weights learned for the existing
    classes are kept.
    """

    def __init__(self, numclass, feature_extractor, feat_dim=512):
        """Build the network.

        Args:
            numclass: Number of outputs of the linear head.
            feature_extractor: Module applied to the input before the head.
                The last dimension of its output must equal ``feat_dim``,
                since ``nn.Linear`` operates on the last dimension.
            feat_dim: Input width of the linear head. Defaults to 512, the
                value that used to be hard-coded.
        """
        super().__init__()
        self.feature = feature_extractor
        self.fc = nn.Linear(feat_dim, numclass, bias=True)

    def forward(self, input):
        """Return ``fc(feature(input))``, i.e. the head's outputs."""
        x = self.feature(input)
        return self.fc(x)

    def Incremental_learning(self, numclass):
        """Replace the linear head with one that has ``numclass`` outputs.

        The leading rows of the old weight and bias are copied into the new
        head. Extra rows keep ``nn.Linear``'s default initialization. If
        ``numclass`` is smaller than the current number of outputs, only the
        first ``numclass`` rows are kept; previously this case raised a
        shape-mismatch error.

        The new head is created on the same device and with the same dtype
        as the old one. Creating it with defaults would leave it on the CPU
        in float32 even when the rest of the model was moved.

        NOTE: the head's Parameter objects are new after this call, so an
        optimizer constructed earlier does not reference them.
        """
        old_fc = self.fc
        keep = min(old_fc.out_features, numclass)
        new_fc = nn.Linear(old_fc.in_features, numclass, bias=True).to(
            device=old_fc.weight.device, dtype=old_fc.weight.dtype
        )
        # In-place copy into leaf parameters must not be tracked by autograd;
        # this replaces the older `.data` idiom.
        with torch.no_grad():
            new_fc.weight[:keep].copy_(old_fc.weight[:keep])
            new_fc.bias[:keep].copy_(old_fc.bias[:keep])
        self.fc = new_fc

    def feature_extractor(self, inputs):
        """Return ``self.feature(inputs)`` without applying the linear head."""
        return self.feature(inputs)