-
Notifications
You must be signed in to change notification settings - Fork 1
/
_playground.py
61 lines (45 loc) · 1.63 KB
/
_playground.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
'''
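Dirty Experiments
Nothing really meaningful here: a quick check of how often the model
recovers a sum-of-squares form from a random expanded polynomial.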
'''
from sympy import *
from sympy.parsing.sympy_parser import parse_expr, standard_transformations, implicit_multiplication
import random
from random import randint as ri
import torch
# Sympy symbols shared by the polynomial generator and the evaluation below.
x, y, z = symbols('x y z')
# Variable pool that random_function picks entries from. Listing y twice
# gives a two-variable setup in which y is drawn twice as often as x; use
# (x, y, z) for three variables, or change the generator parameters.
sym = (x, y, y)
def random_function(n=5, deg=2, cr=1, case_add=6, variables=None):
    """Return a random polynomial, expanded, built from ``n`` random steps.

    Starting from ``f = 1``, each step draws ``ri(0, case_add)``:

    * non-zero: add a term ``c * v**k`` with ``c`` drawn from ``[-cr, cr]``
      (``c`` may be 0, in which case the step adds nothing);
    * zero: multiply ``f`` by ``c * v**k`` with ``c`` drawn from ``[1, cr]``.

    ``v`` is drawn uniformly by index from ``variables`` (duplicate entries
    raise that symbol's weight) and ``k`` from ``[1, deg]``.

    Args:
        n: number of random add/multiply steps.
        deg: maximum exponent of each drawn variable.
        cr: coefficient bound; the multiply step calls ``ri(1, cr)``, which
            raises ``ValueError`` when ``cr < 1``.
        case_add: upper bound of the add-vs-multiply draw; larger values make
            the add step more likely.
        variables: non-empty sequence of sympy symbols to draw from; defaults
            to the module-level ``sym``.

    Returns:
        The sympy ``expand``-ed polynomial (possibly just an integer).
    """
    if variables is None:
        variables = sym
    # ri is inclusive on both ends; for a three-entry pool this is the same
    # ri(0, 2) draw, so the random sequence is unchanged.
    last = len(variables) - 1
    f = 1
    for _ in range(n):
        if ri(0, case_add):
            f = f + ri(-cr, cr) * variables[ri(0, last)]**ri(1, deg)
        else:
            f = f * (ri(1, cr) * variables[ri(0, last)]**ri(1, deg))
    return expand(f)
# Evaluation: for each of `tot` random samples, build f1**2 + f2**2, give the
# model its expanded form, and count the sample as solved if any of `trials`
# predictions parses and expands back to the same polynomial.

# NOTE(review): torch.load unpickles an arbitrary Python object, so only load
# trusted files. On torch >= 2.6 `weights_only` defaults to True, which
# refuses fully pickled model objects; pass weights_only=False there
# (confirm which torch version this runs with).
model = torch.load("model.dat")
# NOTE(review): this only sets an attribute; it does not move any parameters
# to the GPU. Presumably the model's own code (e.g. toSOP) reads
# self.device - confirm.
model.device = torch.device("cuda:0")

cnt = 0      # samples with at least one correct prediction
tot = 200    # number of random samples evaluated
trials = 10  # predictions requested from the model per sample
# Loop-invariant: let parse_expr accept implicit products such as "2x" in
# the model's output.
trs = standard_transformations + (implicit_multiplication,)
for _ in range(tot):
    f1, f2 = random_function(n=3), random_function(n=2)
    f3 = f1**2 + f2**2
    # Expanded reference form; predictions are compared against it directly
    # rather than re-parsing the printed target on every trial.
    f4 = expand(f3)
    # The model reads and writes '^' for powers and no whitespace.
    src = str(f4).replace(' ', '').replace('**', '^')
    tgt = str(f3).replace(' ', '').replace('**', '^')
    print(f'Real : {tgt}')
    for i in range(1, trials + 1):
        rets = model.toSOP(src)[0]
        try:
            ex = parse_expr(rets.replace('^', '**'), transformations=trs)
            correct = expand(ex) == f4
        except Exception:
            # Malformed model output (tokenize/syntax/type errors raised while
            # parsing or expanding) simply counts as a wrong prediction;
            # KeyboardInterrupt/SystemExit still propagate.
            correct = False
        res = "True" if correct else "False"
        print(f'Prediction {i} : {rets} : {res}')
        if correct:
            cnt += 1
            break
print(f"{100 * cnt/tot :.2f} % are correct, with {trials} possible trials.")