# data_loader_split.py
import os
import numpy as np
import imageio
import logging
from nerf_sample_ray_split import RaySamplerSingleImage
import glob
# Module-level logger shared by the functions below.
# NOTE(review): ``__package__`` is empty/None when this file is run as a script or imported as a
# top-level module; ``logging.getLogger`` then returns the root logger -- confirm that is intended.
logger = logging.getLogger(__package__)
########################################################################################################################
# camera coordinate system: x-->right, y-->down, z-->scene (opencv/colmap convention)
# poses are camera-to-world (c2w) transformation matrices
########################################################################################################################
def find_files(dir, exts):
    """Return the sorted, de-duplicated paths in ``dir`` that match any pattern in ``exts``.

    Args:
        dir: Directory to search (non-recursive). The name shadows the builtin but is kept
            because it is part of the existing call signature.
        exts: Iterable of glob patterns, e.g. ``['*.png', '*.jpg', '*.JPG', '*.PNG']``.

    Returns:
        A sorted list of matching paths, each joined with ``dir``. The list is empty when
        ``dir`` is not an existing directory or when nothing matches.
    """
    if not os.path.isdir(dir):
        return []

    # Collect into a set: patterns may overlap, and '*.png' / '*.PNG' hit the same file where
    # glob matching is case-insensitive (e.g. Windows). Returning a file twice would inflate
    # the file count seen by callers.
    matches = set()
    for pattern in exts:
        matches.update(glob.glob(os.path.join(dir, pattern)))
    return sorted(matches)
def load_data_split(basedir, scene, split, skip=1, try_load_min_depth=True, only_img_files=False, use_ray_jitter=True, resolution_level=1):
    """Build one ``RaySamplerSingleImage`` per camera found under ``basedir/scene/split``.

    Files read from the split directory:
        intrinsics/*.txt   16 whitespace-separated numbers, loaded as a 4x4 float32 matrix
        pose/*.txt         same format; passed to the sampler as ``c2w``
        rgb/, mask/, min_depth/
                           optional per-view images (*.png, *.jpg, *.JPG, *.PNG)
        max_depth.txt      optional; its first line is parsed as a float

    Files of the different sub-directories are paired by their position in sorted order.

    Args:
        basedir: Root data directory; a single trailing '/' is stripped.
        scene: Scene sub-directory name.
        split: Split sub-directory name.
        skip: Stride applied to every sorted file list (``files[::skip]``).
        try_load_min_depth: When false, files under ``min_depth/`` are ignored.
        only_img_files: When true, return only the sorted list of paths under ``rgb/``
            (``skip`` is not applied) and do not build any sampler.
        use_ray_jitter: Forwarded unchanged to ``RaySamplerSingleImage``.
        resolution_level: Forwarded unchanged to ``RaySamplerSingleImage``.

    Returns:
        A list of ``RaySamplerSingleImage`` (one per pose file kept after ``skip``), or a list
        of image paths when ``only_img_files`` is true.

    Raises:
        AssertionError: If a camera file is missing, or if the number of intrinsics / rgb /
            mask / min_depth files (when present) differs from the number of pose files.
        ValueError: If a camera file does not hold exactly 16 parseable numbers.
    """
    image_exts = ['*.png', '*.jpg', '*.JPG', '*.PNG']

    def parse_txt(filename):
        # Load a text file of 16 numbers as a 4x4 float32 matrix.
        assert os.path.isfile(filename), 'missing camera file: {}'.format(filename)
        with open(filename) as f:  # context manager: do not leak one handle per camera
            nums = f.read().split()
        return np.array([float(x) for x in nums]).reshape([4, 4]).astype(np.float32)

    if basedir.endswith('/'):  # remove one trailing '/'; endswith() is also safe for ''
        basedir = basedir[:-1]
    split_dir = '{}/{}/{}'.format(basedir, scene, split)

    if only_img_files:
        return find_files('{}/rgb'.format(split_dir), exts=image_exts)

    # camera parameters files
    intrinsics_files = find_files('{}/intrinsics'.format(split_dir), exts=['*.txt'])
    pose_files = find_files('{}/pose'.format(split_dir), exts=['*.txt'])
    logger.info('raw intrinsics_files: {}'.format(len(intrinsics_files)))
    logger.info('raw pose_files: {}'.format(len(pose_files)))

    intrinsics_files = intrinsics_files[::skip]
    pose_files = pose_files[::skip]
    cam_cnt = len(pose_files)
    # Intrinsics and poses are paired by index, so a count mismatch would pair the wrong
    # cameras (or run off the end of the list); fail early like the other per-view lists do.
    assert len(intrinsics_files) == cam_cnt, \
        'found {} intrinsics files but {} pose files in {}'.format(
            len(intrinsics_files), cam_cnt, split_dir)

    def per_view_files(subdir, label, enabled=True):
        # Return the strided file list of an optional sub-directory, or one None per view
        # when the sub-directory is disabled, missing or empty.
        files = find_files('{}/{}'.format(split_dir, subdir), exts=image_exts)
        if not (enabled and files):
            return [None] * cam_cnt
        logger.info('raw {}: {}'.format(label, len(files)))
        files = files[::skip]
        assert len(files) == cam_cnt, \
            'found {} {} but {} pose files in {}'.format(len(files), label, cam_cnt, split_dir)
        return files

    img_files = per_view_files('rgb', 'img_files')
    mask_files = per_view_files('mask', 'mask_files')
    mindepth_files = per_view_files('min_depth', 'mindepth_files', enabled=try_load_min_depth)

    # max depth is a per-split value, so read it once instead of once per camera.
    try:
        with open('{}/max_depth.txt'.format(split_dir)) as f:
            max_depth = float(f.readline().strip())
    except (OSError, ValueError):
        # The file is optional: a missing/unreadable file or an unparseable first line
        # means "no max depth".
        max_depth = None

    # create ray samplers
    ray_samplers = []
    for intrinsics_file, pose_file, img_file, mask_file, mindepth_file in zip(
            intrinsics_files, pose_files, img_files, mask_files, mindepth_files):
        intrinsics = parse_txt(intrinsics_file)
        pose = parse_txt(pose_file)

        if img_file is not None and os.path.exists(img_file):
            # The image is decoded only to learn its height and width.
            H, W = imageio.imread(img_file).shape[:2]
        else:
            # Hard-coded fallback image size used when no image is available for this view.
            H, W = 775, 1044
            print('Couldn\'t find', img_file, ', using img_size=', W, H)

        ray_samplers.append(RaySamplerSingleImage(H=H, W=W, intrinsics=intrinsics, c2w=pose,
                                                  img_path=img_file,
                                                  mask_path=mask_file,
                                                  min_depth_path=mindepth_file,
                                                  max_depth=max_depth,
                                                  use_ray_jitter=use_ray_jitter,
                                                  resolution_level=resolution_level))

    logger.info('Split {}, # views: {}'.format(split, cam_cnt))
    return ray_samplers