/
poly.py
83 lines (67 loc) · 2.57 KB
/
poly.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
""" PyTorchPoly """
import torch
from torch.autograd import Variable
import torch.nn as nn
# import pyindex
def legendre(x, degree):
    """Evaluate the Legendre polynomials P_0..P_degree at every entry of ``x``.

    Uses Bonnet's recurrence
    ``(n + 1) P_{n+1}(x) = (2n + 1) x P_n(x) - n P_{n-1}(x)``
    with ``P_0 = 1`` and ``P_1 = x``.

    Args:
        x: Tensor of sample points. A 1-D tensor of shape ``(N,)`` is the
            usual case; higher-rank inputs are evaluated elementwise.
        degree: Highest polynomial degree (non-negative int).

    Returns:
        Tensor of shape ``x.shape + (degree + 1,)`` (i.e. ``(N, degree + 1)``
        for 1-D input) with the same dtype and device as ``x``; entry
        ``[..., k]`` holds ``P_k(x)``. Differentiable with respect to ``x``.

    Raises:
        ValueError: if ``degree`` is negative.
    """
    if degree < 0:
        raise ValueError("degree must be non-negative, got {}".format(degree))
    # Build the columns out-of-place and stack once. Writing each new column
    # into a preallocated tensor bumps the version counter of the columns
    # autograd saved for the products, so backward() would fail whenever
    # x requires grad.
    cols = [torch.ones_like(x)]
    if degree > 0:
        cols.append(x)
    for n in range(1, degree):
        cols.append(((2 * n + 1) * x * cols[n] - n * cols[n - 1]) / (n + 1))
    return torch.stack(cols, dim=-1)
def chebyshev(x, degree):
    """Evaluate the Chebyshev polynomials (first kind) T_0..T_degree at ``x``.

    Uses the recurrence ``T_{n+1}(x) = 2 x T_n(x) - T_{n-1}(x)`` with
    ``T_0 = 1`` and ``T_1 = x``.

    Args:
        x: Tensor of sample points. A 1-D tensor of shape ``(N,)`` is the
            usual case; higher-rank inputs are evaluated elementwise.
        degree: Highest polynomial degree (non-negative int).

    Returns:
        Tensor of shape ``x.shape + (degree + 1,)`` (i.e. ``(N, degree + 1)``
        for 1-D input) with the same dtype and device as ``x``; entry
        ``[..., k]`` holds ``T_k(x)``. Differentiable with respect to ``x``.

    Raises:
        ValueError: if ``degree`` is negative.
    """
    if degree < 0:
        raise ValueError("degree must be non-negative, got {}".format(degree))
    # T_0 is identically 1; ones_like avoids the NaN that ``x * 0 + 1``
    # produced for infinite/NaN inputs while keeping x's dtype and device.
    cols = [torch.ones_like(x)]
    if degree > 0:
        cols.append(x)
    # Out-of-place construction keeps autograd's saved tensors unmodified,
    # so gradients with respect to x work.
    for n in range(1, degree):
        cols.append(2 * x * cols[n] - cols[n - 1])
    return torch.stack(cols, dim=-1)
def hermite(x, degree):
    """Evaluate the probabilists' Hermite polynomials He_0..He_degree at ``x``.

    Uses the recurrence ``He_{n+1}(x) = x He_n(x) - n He_{n-1}(x)`` with
    ``He_0 = 1`` and ``He_1 = x`` (the monic family, e.g. He_2 = x^2 - 1,
    He_3 = x^3 - 3x). The previous implementation divided by ``n`` instead
    of multiplying, which gave non-Hermite polynomials from degree 3 on.

    Args:
        x: Tensor of sample points. A 1-D tensor of shape ``(N,)`` is the
            usual case; higher-rank inputs are evaluated elementwise.
        degree: Highest polynomial degree (non-negative int).

    Returns:
        Tensor of shape ``x.shape + (degree + 1,)`` (i.e. ``(N, degree + 1)``
        for 1-D input) with the same dtype and device as ``x``; entry
        ``[..., k]`` holds ``He_k(x)``. Differentiable with respect to ``x``.

    Raises:
        ValueError: if ``degree`` is negative.
    """
    if degree < 0:
        raise ValueError("degree must be non-negative, got {}".format(degree))
    cols = [torch.ones_like(x)]
    if degree > 0:
        cols.append(x)
    # Out-of-place construction keeps autograd's saved tensors unmodified,
    # so gradients with respect to x work.
    for n in range(1, degree):
        cols.append(x * cols[n] - n * cols[n - 1])
    return torch.stack(cols, dim=-1)
class UnivariatePoly(nn.Module):
    """Learnable univariate polynomial expressed in an orthogonal basis.

    ``forward`` evaluates the selected basis (Legendre, Chebyshev or
    Hermite) up to ``PolyDegree`` at the inputs and combines the basis
    values with a bias-free ``nn.Linear(PolyDegree + 1, 1)`` whose weights
    act as the polynomial coefficients.

    Args:
        PolyDegree: Highest polynomial degree; the model has
            ``PolyDegree + 1`` learnable coefficients.
        poly_type: Basis name: ``"legendre"``, ``"chebyshev"`` or
            ``"hermite"``.
    """

    def __init__(self, PolyDegree, poly_type):
        super(UnivariatePoly, self).__init__()
        self.degree = PolyDegree
        # One weight per basis polynomial of degree 0..PolyDegree; no bias is
        # needed because the degree-0 basis column is already constant 1.
        self.linear = nn.Linear(PolyDegree + 1, 1, bias=False)
        self.poly_type = poly_type

    def forward(self, x):
        """Evaluate the learned polynomial at ``x``.

        Args:
            x: Tensor of sample points (assumed 1-D of length N, as the
                basis helpers expect).

        Returns:
            Tensor of shape ``(N, 1)`` for a 1-D ``x`` of length N.

        Raises:
            ValueError: if ``self.poly_type`` is not a supported basis. (This
                used to print a message and call ``exit(1)``, terminating the
                whole process from inside a library call.)
        """
        if self.poly_type == "legendre":
            vand = legendre(x, self.degree)
        elif self.poly_type == "chebyshev":
            vand = chebyshev(x, self.degree)
        elif self.poly_type == "hermite":
            vand = hermite(x, self.degree)
        else:
            raise ValueError(
                "No Polynomial type {!r} is implemented".format(self.poly_type)
            )
        return self.linear(vand)
# class MultiLegendre(nn.Module):
# def __init__(self, dim, order):
# super(MultiLegendre, self).__init__()
# self.order = order
# self.dim = dim
# self.num = pyindex.TotalOrder(self.dim, self.order).get_count()
# self.polys = nn.ModuleList([Legendre(order) for i in range(dim)])
# def forward(self, x):
# vs = [p(x[:, ii]) for ii, p in enumerate(self.polys)]
# v = Variable(torch.ones((x.size(0), self.num)).type(x.data.type()))
# iterator = pyindex.TotalOrder(self.dim, self.order)
# for ii, inds in enumerate(iterator):
# for jj, ind in enumerate(inds):
# v[:,ii] = v[:, ii] * vs[jj][:, ind]
# return v
# def num_parameters(self):
# return self.num