/
train_val_split.py
40 lines (28 loc) · 1.19 KB
/
train_val_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# -*- coding: utf-8 -*-
# @Time : 2020/8/6 10:46
# @Author : runze.wang
import os
import shutil
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--folder_path', default='/home/gdp/data/img_celeba', type=str,help='The folder path')
parser.add_argument('--train_folder', default='/home/gdp/data/img_celeba/train_folder', type=str)
parser.add_argument('--validation_folder', default='/home/gdp/data/img_celeba/validation_folder', type=str)
def copy_file(source_dir_list, target_dir):
if not os.path.exists(target_dir):
os.makedirs(target_dir)
for dir in source_dir_list:
shutil.copy(dir, target_dir)
if __name__ == "__main__":
args = parser.parse_args()
dirs = os.listdir(args.folder_path)
train_name_list = []
validation_name_list = []
for i in range(len(dirs)):
if dirs[i].split('.')[-1] == 'jpg':
if i <= int(0.8*len(dirs)):
train_name_list.append(os.path.join(args.folder_path, dirs[i]))
else:
validation_name_list.append(os.path.join(args.folder_path, dirs[i]))
copy_file(train_name_list, args.train_folder)
copy_file(validation_name_list, args.validation_folder)