-
Notifications
You must be signed in to change notification settings - Fork 8
/
main.py
113 lines (96 loc) · 4.91 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# encoding=utf-8
import argparse
import os
import random
import shutil
import sys

import numpy as np
import torch

from models.PDGNet import PDGNet
from models.PDGNet_v2 import PDGNet_v2
def str2bool(value):
    """Convert a command-line string such as ``'False'`` or ``'1'`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``'False'``). Boolean options use this
    converter instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling,
            so argparse reports a normal usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)


def parse_args():
    """Build the command-line parser, parse ``sys.argv`` and validate it.

    Returns:
        The result of ``check_args`` on the parsed namespace.
    """
    desc = "Pytorch PointGAN"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--phase', type=str, default='train', choices=['train', 'test', 'cls'],
                        help='train, test or cls (feature extraction) [default: train]')
    parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
    parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
    parser.add_argument('--batch_size', type=int, default=50, help='input batch size [default: 50]')
    parser.add_argument('--num_point', type=int, default=2048,
                        help='Point Number [256/512/1024/2048] [default: 2048]')
    parser.add_argument('--num_k', type=int, default=20, help='number of the knn graph point')
    parser.add_argument('--learning_rate', type=float, default=0.0001,
                        help='Initial learning rate [default: 0.0001]')
    parser.add_argument('--max_epoch', type=int, default=300, help='number of epochs to train for')
    parser.add_argument('--noise_dim', type=int, default=128, help='dimensional of noise')
    parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
    # type=bool would turn '--debug False' into True; str2bool parses the text.
    parser.add_argument('--debug', type=str2bool, default=True, help='print log [default: True]')
    parser.add_argument('--data_root', default='/opt/data/private/shapenet/shapenet.hdf5',
                        help='data root [default: xxx]')
    parser.add_argument('--log_info', default='log_info.txt', help='log_info txt')
    parser.add_argument('--model_dir', help='model dir [default: None, must input]')
    parser.add_argument('--checkpoint_dir', default='checkpoint',
                        help='Checkpoint dir [default: checkpoint]')
    parser.add_argument('--snapshot', type=int, default=20, help='how many epochs to save model')
    parser.add_argument('--choice', default=None, help='choice class')
    parser.add_argument('--network', default=None, help='which network model to be used')
    parser.add_argument('--savename', default=None, help='the generate data name')
    parser.add_argument('--pretrain_model_G', default=None, help='use the pretrain model G')
    parser.add_argument('--pretrain_model_D', default=None, help='use the pretrain model D')
    # NOTE(review): kept as the *string* 'True' for compatibility with whatever
    # consumer reads it in the model code; confirm how it is compared there.
    parser.add_argument('--softmax', default='True', help='softmax for bilaterl interpolation')
    parser.add_argument('--dataset', default='shapenet15k',
                        help='choice dataset [shapenet15k, modelnet10, modelnet40]')
    parser.add_argument('--normalize', type=str, default='shape_bbox',
                        choices=[None, 'shape_unit', 'shape_bbox'])
    parser.add_argument('--seed', type=int, default=9999)
    parser.add_argument('--save_dir', type=str, default='./results')
    parser.add_argument('--device', type=str, default='cuda')
    return check_args(parser.parse_args())


# Example usage (--network must be PDGNet or PDGNet_v2):
#   CUDA_VISIBLE_DEVICES=7 python main.py --network PDGNet --choice Chair --snapshot 2 --model_dir PDGNet_20190301
def check_folder(dir):
    """Create directory ``dir`` (including missing parents) if it does not exist.

    Args:
        dir: Path of the directory to ensure exists. The parameter name is kept
            for backward compatibility even though it shadows the builtin.

    Returns:
        The same ``dir`` value, so the call can be used inline.

    Raises:
        FileExistsError: if ``dir`` exists but is not a directory.
    """
    # exist_ok=True removes the window between a separate exists() check and
    # makedirs(), e.g. when two runs create the same checkpoint folder.
    os.makedirs(dir, exist_ok=True)
    return dir
def check_args(args):
    """Validate parsed command-line arguments and create the output folders.

    Args:
        args: The ``argparse.Namespace`` produced by ``parse_args``.

    Returns:
        ``args`` on success. Returns ``None`` when ``--max_epoch`` or
        ``--batch_size`` is smaller than one, so the caller can abort. In that
        case no folders are created.

    Side effects:
        Exits the process with status 1 when ``--model_dir`` or ``--network``
        is missing. Otherwise creates ``checkpoint_dir`` and
        ``checkpoint_dir/model_dir``.
    """
    if args.model_dir is None:
        print('please create model dir')
        sys.exit(1)
    if args.network is None:
        print('please select model!!!')
        sys.exit(1)

    # Explicit checks instead of `assert` inside a bare `except:`. The old form
    # only printed a message and let invalid values through, and the assert is
    # removed entirely under `python -O`.
    if args.max_epoch < 1:  # --max_epoch
        print('number of epochs must be larger than or equal to one')
        return None
    if args.batch_size < 1:  # --batch_size
        print('batch size must be larger than or equal to one')
        return None

    # Create folders only after validation so a rejected run leaves nothing behind.
    check_folder(args.checkpoint_dir)  # --checkpoint_dir
    check_folder(os.path.join(args.checkpoint_dir, args.model_dir))  # --checkpoint_dir/--model_dir
    return args
def _backup_sources(args):
    """Copy ``main.py`` and ``models/<network>.py`` into the run's checkpoint folder.

    This is best-effort. A failed copy only prints a warning, as the original
    ``os.system('cp ...')`` calls ignored their exit status. ``shutil.copy`` is
    used so that paths with spaces work and no CLI text reaches a shell.
    Source paths are relative to the current working directory.
    """
    dest = os.path.join(args.checkpoint_dir, args.model_dir)
    for src in ('main.py', os.path.join('models', '%s.py' % args.network)):
        try:
            shutil.copy(src, dest)  # copies *into* dest when dest is a directory
        except OSError as err:
            print(' [!] could not back up %s: %s' % (src, err))


def main():
    """Parse arguments, build the selected network and run the requested phase.

    Phases:
        ``train`` backs up the sources, then calls ``gan.train()``.
        ``test`` calls ``gan.test()``.
        ``cls`` calls ``gan.extract_feature()``.
    Invalid arguments, an unknown network or an unknown phase exit with status 1.
    """
    args = parse_args()
    if args is None:
        sys.exit(1)

    # A new random seed is drawn for every run and printed so the run can be
    # reproduced. NOTE(review): `--seed` is not used here; confirm whether the
    # models read `args.seed` themselves.
    args.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", args.manualSeed)
    random.seed(args.manualSeed)
    np.random.seed(args.manualSeed)  # numpy RNG must be seeded too for reproducibility
    torch.manual_seed(args.manualSeed)

    # create model
    print('****************network: {}****************'.format(args.network))
    networks = {'PDGNet': PDGNet, 'PDGNet_v2': PDGNet_v2}
    if args.network not in networks:
        print('select model error!!!')
        sys.exit(1)
    gan = networks[args.network](args)
    gan.build_model()

    if args.phase == 'train':
        _backup_sources(args)  # keep a copy of main.py / model file with the checkpoints
        gan.train()
        print(" [*] Training finished!")
    elif args.phase == 'test':
        gan.test()
        print(" [*] Test finished!")
    elif args.phase == 'cls':
        gan.extract_feature()
        print(" [*] Extract feature finished!")
    else:
        print("unknown phase '%s' (expected train, test or cls)" % args.phase)
        sys.exit(1)


if __name__ == '__main__':
    main()