-
Notifications
You must be signed in to change notification settings - Fork 6
/
run_compression.py
36 lines (25 loc) · 903 Bytes
/
run_compression.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed May 12 14:02:47 2021
@author: ben
"""
import glob
import os
import sys

from continuous_net.tools.data_transform import DataTransform
from continuous_net.convergence import ConvergenceTester
import datasets
# Directory whose entries are each handed to ConvergenceTester below.
DIR = "../stateful_cifar10/"

# os.path.join avoids the doubled "//" that f"{DIR}/*" produced (DIR already
# ends with a slash). glob() returns entries in arbitrary order, so sort them
# to make the sweep order reproducible between runs.
paths = sorted(glob.glob(os.path.join(DIR, "*")))
if not paths:
    # Without this, a wrong/missing DIR makes the script silently do nothing.
    print(f"No entries found in {DIR!r}; nothing to evaluate.", file=sys.stderr)

# Load the CIFAR10 splits via the project-local `datasets` helper and wrap each
# one in DataTransform. Only `test_data` is consumed by the loop below.
torch_train_data, torch_validation_data, torch_test_data = datasets.get_dataset(
    name='CIFAR10', root='../'
)
train_data = DataTransform(torch_train_data)
validation_data = DataTransform(torch_validation_data)
test_data = DataTransform(torch_test_data)

for path in paths:
    # Best-effort sweep: a failure on one entry is reported and the loop moves
    # on, so one broken entry does not abort the remaining ones.
    try:
        print(path)
        ct = ConvergenceTester(path)
        print(ct.eval_model)
        ct.perform_project_and_infer(
            test_data,
            bases=["piecewise_constant"],
            n_bases=[16, 8],
            schemes=['Euler'],
            n_steps=[16, 8],
        )
    except Exception as e:
        # Include the exception type: str(e) alone is often empty or ambiguous.
        print(f"Error with {path}: {type(e).__name__}: {e}", file=sys.stderr)