/
calculate_accuracy.py
84 lines (75 loc) · 2.95 KB
/
calculate_accuracy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from model import Model
import torch
from skimage import measure
import os
import numpy as np
from tqdm import tqdm
import wandb
# W&B public-API client, constructed once at import time; get_runs() uses it
# to query runs. NOTE(review): presumably needs W&B credentials configured in
# the environment — confirm.
api = wandb.Api()
def get_accs(level, task, strategy):
    """Average test accuracy over every qualifying run of one configuration.

    Args:
        level: quantization level passed through to get_runs() and predict().
        task: task name passed through to get_runs() and predict().
        strategy: init strategy passed through to get_runs().

    Returns:
        The mean of predict() over the run ids returned by get_runs(), or
        ``np.nan`` when get_runs() returns no ids.
    """
    run_ids = get_runs(level, task, strategy)
    if len(run_ids) > 0:
        per_run_acc = [predict(run_id, level, task) for run_id in run_ids]
        return np.mean(per_run_acc)
    # Nothing matched this configuration: report a missing value, not 0.
    return np.nan
def get_runs(level, task, strategy, *, expected_val_losses=75):
    """Return ids of finished W&B runs for one configuration that logged a full history.

    Queries the ``colin-cooke/ctc`` project for runs whose state is
    ``finished`` and whose config matches ``task``, ``level`` and
    ``init_strategy``. A run is kept only if its history contains exactly
    ``expected_val_losses`` rows carrying a ``val_loss`` entry.

    Args:
        level: value to match against ``config.level``.
        task: value to match against ``config.task``.
        strategy: value to match against ``config.init_strategy``.
        expected_val_losses: required number of ``val_loss`` entries
            (default 75).

    Returns:
        List of run id strings, in the order the API yields the runs.
    """
    filters = {
        "$and": [
            {"state": "finished"},
            {"config.task": task},
            {"config.level": level},
            {"config.init_strategy": strategy},
        ]
    }
    ids = []
    for run in tqdm(api.runs("colin-cooke/ctc", filters)):
        hist = run.history(pandas=False)
        # Count only rows that actually carry val_loss. Indexing h['val_loss']
        # raised KeyError on any row logged without that key; when every row
        # has it, this count equals the original list length.
        n_val_losses = sum(1 for row in hist if 'val_loss' in row)
        if n_val_losses == expected_val_losses:
            ids.append(run.id)
    return ids
def predict(run_id, level, task, *, data_dir='/hddraid5/data/colin/ctc/', batch_size=4):
    """Return the quantized pixel accuracy of one trained run on the test set.

    Loads ``models/model_<run_id>.pth`` into a fresh ``Model`` and
    ``models/net_0_<run_id>.pth`` into its first sub-network, runs the model
    over the memory-mapped test inputs for ``task`` in batches, multiplies
    predictions and targets by ``level``, rounds both, and returns the
    fraction of elements where they agree.

    Args:
        run_id: W&B run id used to locate the checkpoint files.
        level: quantization level. ``int(log2(level))`` selects the label
            file, so it is presumably a power of two.
        task: ``'hela'`` or ``'pan'`` (case-insensitive).
        data_dir: root directory holding the test arrays and ``models/``.
        batch_size: number of test samples per forward pass.

    Returns:
        Fraction of elements whose rounded prediction equals the rounded
        target.

    Raises:
        RuntimeError: if ``task`` is not ``'hela'`` or ``'pan'``.
        ValueError: if the prediction and label arrays differ in shape.
    """
    # Normalize once so input and label selection always agree.
    # Previously the labels used a case-sensitive check.
    task_key = task.lower()
    bits = int(np.log2(level))
    # For now we only ever load test data.
    if task_key == 'hela':
        x_name, y_name = 'test_x_norm.npy', f'new_nuc_test_kb{bits}.npy'
    elif task_key == 'pan':
        x_name, y_name = 'pan_test_x.npy', f'pan_test_{bits}_y.npy'
    else:
        raise RuntimeError(f"unknown task {task!r}; expected 'hela' or 'pan'")

    model_dir = os.path.join(data_dir, 'models')
    model = Model(num_heads=1, batch_norm=True)
    # Inference below runs on CPU, so map checkpoints to CPU.
    # This avoids needing the device they may have been saved from.
    model_state = torch.load(os.path.join(model_dir, f'model_{run_id}.pth'), map_location='cpu')
    unet_state = torch.load(os.path.join(model_dir, f'net_0_{run_id}.pth'), map_location='cpu')
    model.load_state_dict(model_state)
    # Loaded second so these weights override nets[0] from the full-model state.
    model.nets[0].load_state_dict(unet_state)
    # Inference mode: with batch_norm=True, train mode would normalize each
    # batch by its own statistics and update the running averages.
    # (Assumes Model is a torch.nn.Module — TODO confirm.)
    model.eval()

    test_x = torch.from_numpy(np.load(os.path.join(data_dir, x_name), mmap_mode='r'))
    y_true = np.load(os.path.join(data_dir, y_name))

    preds = []
    with torch.no_grad():
        for start in tqdm(range(0, test_x.shape[0], batch_size)):
            batch = test_x[start:start + batch_size].float()
            # Keep output element 0. Presumably this is the single head's
            # (batch, channel, ...) output (num_heads=1) — TODO confirm
            # against Model.forward.
            preds.append(model(batch).cpu().numpy()[0])
    preds = np.concatenate(preds, axis=0)[:, 0]

    # Quantize predictions and targets to `level` steps before comparing.
    preds_rounded = np.round(preds * level)
    y_true_rounded = np.round(y_true * level)
    if preds_rounded.shape != y_true_rounded.shape:
        # A broadcasting comparison here would give a meaningless accuracy.
        raise ValueError(
            f"prediction shape {preds_rounded.shape} does not match "
            f"label shape {y_true_rounded.shape}")
    return np.count_nonzero(preds_rounded == y_true_rounded) / preds_rounded.size
if __name__ == "__main__":
    # Quantization levels 2**1 .. 2**7.
    levels = [2, 4, 8, 16, 32, 64, 128]
    tasks = ['pan']
    # NOTE: the 'learned' strategy is currently excluded from this sweep.
    strategies = ['dpc', 'off_axis', 'all', 'random', 'center']
    out_dir = '/hddraid5/data/colin/ctc/accs'
    os.makedirs(out_dir, exist_ok=True)
    for task in tasks:
        for strategy in strategies:
            # One mean accuracy per level; get_accs yields NaN where no run qualifies.
            accs = np.array([get_accs(level, task, strategy) for level in levels])
            np.save(os.path.join(out_dir, f'{task}_{strategy}.npy'), accs)