
Commit 5fbdb95

Author: Dimitri Coukos (committed)

fix bug in dataset indexing

1 parent 22c8b4d, commit 5fbdb95

File tree

5 files changed (+45, -10 lines changed)

.DS_Store

0 Bytes; binary file not shown.

README.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -11,3 +11,5 @@ Experiment 1:
 - Depth: {2, 4, 6, 8, 10, 12... while perf increasing}
 - Input Data: {Masif identifiers, Electrostatics, + Shape Index, +Rotated Positional Data}
 - Uses SeLU because ReLU kills the learning...
+
+Observations & Results:
```
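A side note on the "ReLU kills the learning" bullet: that is the dying-ReLU failure mode, where units stuck in the negative regime emit 0 and receive no gradient, while SELU's scaled exponential negative tail keeps gradients flowing. A minimal, hypothetical PyTorch sketch of the swap (not code from this repository):

```python
import torch.nn as nn

# Hypothetical illustration: SELU has a nonzero slope for negative inputs,
# so a unit that drifts negative still passes gradient, whereas a ReLU
# unit outputs a constant 0 there and can stop learning entirely.
block = nn.Sequential(
    nn.Linear(3, 64),
    nn.SELU(),  # instead of nn.ReLU()
)
```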

dataset.py

Lines changed: 8 additions & 7 deletions
```diff
@@ -12,6 +12,8 @@
 File to generate the dataset from the ply files.

 '''
+
+
 def convert_data(path_to_raw='./structures/', n=None, prefix='full'):
     '''Generate raw unprocessed torch file to generate pyg datasets with fewer
     candidates.
@@ -160,7 +162,7 @@ def read_ply(path, learn_iface=True):

     x = ([torch.tensor(data['vertex'][axis]) for axis in ['charge', 'hbond', 'hphob']])
     x = torch.stack(x, dim=-1)
-    y = None #what the fuck
+    y = None

     y = [torch.tensor(data['vertex']['iface'])]
     y = torch.stack(y, dim=-1)
@@ -254,16 +256,15 @@ def __init__(self, root='./datasets/{}/'.format(p.dataset), pre_transform=None,
         super(StructuresDataset, self).__init__(root, transform, pre_transform)
         self.has_nan = []

-
     @property
     def raw_file_names(self):
-        n_files = len(glob('{}/raw/full_structure_*'.format(self.root, p.dataset)))
+        n_files = len(glob('{}/raw/full_structure_*'.format(self.root)))
         return ['full_structure_{}.pt'.format(idx) for idx in range(0, n_files)]

     @property
     def processed_file_names(self):
-        n_files = len(glob('./datasets/{}/processed/data*'.format(p.dataset)))
-        return ['data_0.pt'] # right order
+        n_files = len(glob('{}/processed/data*'.format(self.root)))
+        return ['data_{}.pt'.format(i) for i in range(0, n_files)] # right order

     def download(self):
         pass
@@ -285,8 +286,8 @@ def process(self):
         torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
         i += 1

-    def __len__(self):
-        return len(self.processed_file_names)
+    def len(self):
+        return len(self.processed_paths)

     def get(self, idx):
         data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
```
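Why the hunks above fix the indexing bug: with only `['data_0.pt']` listed in `processed_file_names`, `len(self.processed_paths)` is 1, so the dataset exposed a single sample regardless of how many chunks were on disk. In the torch_geometric version this code appears to target, `Dataset` subclasses override `len()` and `get(idx)` (the base class wires them into `__len__`/`__getitem__` itself). A minimal sketch of that contract, with names mirroring `StructuresDataset` but written as an illustration rather than the repository's code:

```python
import os.path as osp
from glob import glob

import torch
from torch_geometric.data import Dataset


class SketchDataset(Dataset):
    """Illustrative skeleton of the PyG Dataset contract."""

    @property
    def raw_file_names(self):
        # Files expected under <root>/raw; if any are missing, download() runs.
        return []

    @property
    def processed_file_names(self):
        # Enumerate *all* processed chunks so indexing covers the full
        # dataset, not just 'data_0.pt'.
        n = len(glob(osp.join(self.processed_dir, 'data_*.pt')))
        return ['data_{}.pt'.format(i) for i in range(n)]

    def download(self):
        pass

    def process(self):
        pass

    def len(self):  # PyG calls this, not __len__
        return len(self.processed_paths)

    def get(self, idx):  # PyG calls this, not __getitem__
        return torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
```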

model_22.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -84,6 +84,7 @@
 # ---- Training ----

 for model_n, model in enumerate(models):
+
     model.to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate, weight_decay=p.weight_decay)
     # ------------ TRAINING NEW BLOCK --------------------------
@@ -94,11 +95,12 @@
     masked_loader = DataLoader(maskedset, shuffle=False, batch_size=p.test_batch_size)

     data = next(iter(train_loader))
-    ns = NeighborSampler(next(iter(train_loader)), 0.92, 9, batch_size=1000)
+    ns = NeighborSampler(next(iter(train_loader)), 0.4, 9, batch_size=1)

     # error with NeighborSampler:
     # neighbor sampler does not seem to be iterable like in the example.
-
+    for dataflow in ns():
+        print(dataflow)
     model.train()
     first_batch_labels = torch.Tensor()
     pred = torch.Tensor()
```
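On the "not iterable" comment in the hunk above: in the PyG 1.x API this code appears to target, `NeighborSampler` is a callable factory rather than an iterable, so iteration goes through the result of calling the object, as the fix does. A hedged sketch of that pattern (argument names and semantics varied across early PyG releases, so treat this as an approximation, not the library's definitive API):

```python
from torch_geometric.data import NeighborSampler

# PyG 1.x-style usage (approximate): size is the fraction of neighbors
# sampled per hop, num_hops the sampling depth.
ns = NeighborSampler(data, size=0.4, num_hops=9, batch_size=1)

for dataflow in ns():  # calling the sampler yields DataFlow mini-batches
    print(dataflow)    # each DataFlow describes one sampled computation graph
```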

playground.py

Lines changed: 31 additions & 1 deletion
```diff
@@ -478,5 +478,35 @@ def test(mask):
 print('Epoch: {:02d}, Loss: {:.4f}, Test: {:.4f}'.format(
     epoch, loss, test_acc))

-
+# ---------------------- Trying to use datastructures ----------------------------
+import torch
 from dataset import StructuresDataset
+from transforms import *
+from torch_geometric.transforms import *
+from models import TwoConv
+import params as p
+
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+cpu = torch.device('cpu')
+# reproducibility
+torch.manual_seed(p.random_seed)
+np.random.seed(p.random_seed)
+learn_rate = p.learn_rate
+
+
+model = TwoConv(3, heads=p.heads).to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate, weight_decay=p.weight_decay)
+
+trainset = StructuresDataset(root='./datasets/full_train_ds/',
+                             pre_transform=Compose((FaceAttributes(), NodeCurvature(),
+                                                    FaceToEdge(), TwoHop())))
+
+samples = len(trainset)
+cutoff = int(np.floor(samples*(1-p.validation_split)))
+train_indices = torch.tensor([i for i in range(0, cutoff)])
+train = trainset[train_indices]
+
+validset = trainset[cutoff:]
+trainset = trainset[:cutoff]
+
+sorted(glob.glob('./datasets/full_train_ds/processed/data_*.pt'))
```
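One caveat on the trailing `sorted(glob.glob(...))` line in the hunk above: its result is discarded as written, and plain lexicographic sorting orders `data_10.pt` before `data_2.pt`, which is exactly the kind of indexing mismatch this commit is about. A small illustrative fix using a numeric sort key (not code from the repository):

```python
import glob
import re

# Sort processed chunks by their numeric suffix so data_2.pt precedes
# data_10.pt; a bare lexicographic sort interleaves them incorrectly.
paths = sorted(
    glob.glob('./datasets/full_train_ds/processed/data_*.pt'),
    key=lambda s: int(re.search(r'data_(\d+)\.pt$', s).group(1)),
)
```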
