Commit 57fc227

WIP cnn

1 parent 638830b commit 57fc227

File tree

7 files changed: +123 -45 lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
@@ -116,4 +116,4 @@ venv.bak/
 
 # Other
 ___*
-
+resources/data/*

requirements.txt

2.26 KB
Binary file not shown.

src/classifier/GenerateDataSets.py

Lines changed: 16 additions & 8 deletions
@@ -7,14 +7,22 @@
 if __name__ == '__main__':
     resource_dir = pathlib.Path(__file__).parent.parent.parent / "resources" / "data"
 
-    nb_models = 100
-    for sub_folder in ["tests", "training", "validation"]:
-        for i in range(0, nb_models):
-            file_path = resource_dir / sub_folder / ("TrainingData_" + str(i) + ".json")
+    ''''for i in range(0, 700):
+        file_path = resource_dir / "training" / ("TrainingData_" + str(i) + ".json")
 
-            nb_parts = random.randrange(25, 100)
-            sampling = random.randrange(25, 100)
+        nb_parts = random.randrange(50, 150)
+        sampling = random.randrange(50, 150)
 
-            model = create_training_model(nb_parts=nb_parts, nb_points_per_mesh=sampling)
+        model = create_training_model(nb_parts=nb_parts, nb_points_per_mesh=sampling)
 
-            model.save_model(str(file_path))
+        model.save_model(str(file_path))'''''
+
+    for i in range(0, 300):
+        file_path = resource_dir / "tests" / ("TrainingData_" + str(i) + ".json")
+
+        nb_parts = random.randrange(50, 150)
+        sampling = random.randrange(50, 150)
+
+        model = create_training_model(nb_parts=nb_parts, nb_points_per_mesh=sampling)
+
+        model.save_model(str(file_path))
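
A note on the output layout: the PipelineDataset introduced in src/classifier/trainer/Training.py below lists its input files from a raw/ subfolder of each split directory (the usual torch_geometric.data.Dataset layout), while this script writes the JSON files directly into resources/data/tests/. Below is a minimal sketch of a generator that writes straight into that raw/ layout; the helper name and the commented-out import of create_training_model are assumptions mirroring GenerateDataSets.py, not part of the commit:

import pathlib
import random

# Assumed import: create_training_model is the project's generator, as used in
# GenerateDataSets.py; its exact module path is not shown in this commit.
# from <generator module> import create_training_model


def generate_split(resource_dir: pathlib.Path, split: str, count: int) -> None:
    # Write into <split>/raw/ so torch_geometric's Dataset picks the files up.
    raw_dir = resource_dir / split / "raw"
    raw_dir.mkdir(parents=True, exist_ok=True)

    for i in range(count):
        nb_parts = random.randrange(50, 150)
        sampling = random.randrange(50, 150)

        model = create_training_model(nb_parts=nb_parts, nb_points_per_mesh=sampling)
        model.save_model(str(raw_dir / ("TrainingData_" + str(i) + ".json")))


if __name__ == '__main__':
    resource_dir = pathlib.Path(__file__).parent.parent.parent / "resources" / "data"
    generate_split(resource_dir, "training", 700)
    generate_split(resource_dir, "tests", 300)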

src/classifier/core/CNNModel.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from torch import Module
+from torch.nn import Module
 
 
 class CNNModel(Module):
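
For reference, Module lives in torch.nn rather than in the top-level torch package, so this one-line fix is what makes the class definition importable. A minimal sketch of how the still-empty CNNModel might look once layers are added; the layer sizes and the pointwise Conv1d layout are illustrative assumptions, not part of the commit:

from torch import nn


class CNNModel(nn.Module):
    """Sketch only: the architecture below is a placeholder, not the project's."""

    def __init__(self, in_channels: int = 3, num_classes: int = 10):
        super().__init__()
        # Pointwise 1D convolutions over per-point features, then a classifier head.
        self.features = nn.Sequential(
            nn.Conv1d(in_channels, 64, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=1),
            nn.ReLU(),
        )
        self.head = nn.Linear(128, num_classes)

    def forward(self, x):
        # x: (batch, in_channels, num_points) -> (batch, num_points, num_classes)
        h = self.features(x)
        return self.head(h.transpose(1, 2))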

src/classifier/trainer/PipelineDataset.py

Lines changed: 0 additions & 35 deletions
This file was deleted.

src/classifier/trainer/Training.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
import os.path as osp
import numpy as np

import wandb
import torch
import pathlib

from torch_geometric.data import Dataset, Data
from torch_points3d.datasets.base_dataset import BaseDataset
from torch_points3d.metrics.classification_tracker import ClassificationTracker

from src.classifier.trainer.TrainingModel import TrainingModel

from os import listdir


class GlobalDataset(BaseDataset):

    def __init__(self):
        super().__init__()

        resource_dir = pathlib.Path(__file__).parent.parent.parent / "resources" / "data"

        self.train_dataset = PipelineDataset(str(resource_dir / "training"), for_labels=True)
        self.test_dataset = PipelineDataset(str(resource_dir / "tests"), for_labels=False)

    def get_tracker(self, wandb_log: bool, tensorboard_log: bool = False):
        return ClassificationTracker(self, wandb_log=wandb_log, use_tensorboard=tensorboard_log)


class PipelineDataset(Dataset):

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, for_labels=True):
        self._files = [f for f in listdir(osp.join(root, "raw")) if osp.isfile(osp.join(root, "raw", f))]
        self._for_labels = for_labels

        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return self._files

    @property
    def processed_file_names(self):
        return [osp.splitext(f)[0] + "_label" + ".pt" if self._for_labels
                else osp.splitext(f)[0] + "_geometry" + ".pt" for f in self._files]

    def process(self):
        idx = 0
        for raw_path in self.raw_paths:
            model = TrainingModel(json_file=raw_path).point_cloud_labelled

            x = np.zeros((len(model.points), 3)) if self._for_labels else np.zeros((len(model.points), 4))  # xyz (+ part type when predicting geometry)
            y = np.zeros(len(model.points)) if self._for_labels else np.zeros((len(model.points), 6))  # part-type label, or center + direction

            for i in range(0, len(model.points)):
                point = model.points[i]

                x[i][0] = point.x
                x[i][1] = point.y
                x[i][2] = point.z

                if not self._for_labels:
                    x[i][3] = point.part_type.value

                if self._for_labels:
                    y[i] = point.part_type.value
                else:
                    y[i][0] = point.center[0]
                    y[i][1] = point.center[1]
                    y[i][2] = point.center[2]

                    y[i][3] = point.direction[0]
                    y[i][4] = point.direction[1]
                    y[i][5] = point.direction[2]

            data = Data(x=torch.from_numpy(x).float(), y=torch.from_numpy(y).float())  # Data expects tensors

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data


if __name__ == '__main__':
    wandb.init(project="PyPipes", entity="Kodvir")

    wandb.config = {
        "learning_rate": 0.001,
        "epochs": 100,
        "batch_size": 128
    }
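
The __main__ block above only initialises Weights & Biases so far, which matches the "WIP cnn" commit message: no model is instantiated or trained. Below is a rough sketch of how the config, the dataset and the still-empty CNNModel could be wired together; the DataLoader import, the loss choice, and the model(batch.x) call signature are assumptions, not part of this commit:

import pathlib
import torch
import wandb
from torch_geometric.loader import DataLoader  # assumed; the commit does not import a loader

from src.classifier.core.CNNModel import CNNModel
from src.classifier.trainer.Training import PipelineDataset


if __name__ == '__main__':
    wandb.init(project="PyPipes", entity="Kodvir")
    config = {"learning_rate": 0.001, "epochs": 100, "batch_size": 128}
    wandb.config.update(config)

    resource_dir = pathlib.Path(__file__).parent.parent.parent / "resources" / "data"
    train_set = PipelineDataset(str(resource_dir / "training"), for_labels=True)
    train_loader = DataLoader(train_set, batch_size=config["batch_size"], shuffle=True)

    model = CNNModel()  # assumes CNNModel eventually defines layers and forward()
    optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
    criterion = torch.nn.CrossEntropyLoss()  # per-point part-type classification

    for epoch in range(config["epochs"]):
        for batch in train_loader:
            optimizer.zero_grad()
            scores = model(batch.x)  # exact input format depends on how CNNModel is defined
            loss = criterion(scores, batch.y.long())
            loss.backward()
            optimizer.step()
            wandb.log({"epoch": epoch, "loss": loss.item()})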

wandb/settings

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+[default]
+
