
Commit

quality improvements, param changes, camera optimization support, misc updates
kerrj committed Oct 18, 2023
1 parent 40c05be commit bb7c7f9
Showing 9 changed files with 228 additions and 122 deletions.
13 changes: 8 additions & 5 deletions nerfstudio/cameras/camera_optimizers.py
@@ -45,10 +45,10 @@ class CameraOptimizerConfig(InstantiateConfig):
     mode: Literal["off", "SO3xR3", "SE3"] = "off"
     """Pose optimization strategy to use. If enabled, we recommend SO3xR3."""
 
-    trans_l2_penalty: float = 1e-2
+    trans_l2_penalty: float = 1e-4
     """L2 penalty on translation parameters."""
 
-    rot_l2_penalty: float = 1e-3
+    rot_l2_penalty: float = 1e-4
     """L2 penalty on rotation parameters."""
 
     optimizer: Optional[OptimizerConfig] = field(default=None)
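For context, these penalties weight an L2 regularizer on the learned pose deltas. A minimal sketch of the shape of that term, assuming a [num_cameras, 6] parameter of (translation, rotation) deltas; the exact reduction in get_loss_dict may differ:

import torch

trans_l2_penalty, rot_l2_penalty = 1e-4, 1e-4  # the new defaults
pose_adjustment = torch.zeros(10, 6)  # hypothetical [num_cameras, (tx, ty, tz, rx, ry, rz)]
reg = (
    trans_l2_penalty * pose_adjustment[:, :3].square().sum(dim=-1).mean()
    + rot_l2_penalty * pose_adjustment[:, 3:].square().sum(dim=-1).mean()
)

Lowering the defaults by one to two orders of magnitude lets the optimizer move cameras further before the regularizer pushes back.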
@@ -150,9 +150,12 @@ def apply_to_raybundle(self, raybundle: RayBundle) -> None:
     def apply_to_camera(self, camera: Cameras) -> None:
         """Apply the pose correction to the camera"""
         if self.config.mode != "off":
-            correction_matrices = self(raybundle.camera_indices.squeeze())  # type: ignore
-            raybundle.origins = raybundle.origins + correction_matrices[:, :3, 3]
-            raybundle.directions = torch.bmm(correction_matrices[:, :3, :3], raybundle.directions[..., None]).squeeze()
+            assert camera.metadata is not None, "Must provide id of camera in its metadata"
+            assert "cam_idx" in camera.metadata, "Must provide id of camera in its metadata"
+            camera_idx = camera.metadata["cam_idx"]
+            adj = self([camera_idx])  # type: ignore
+            adj = torch.cat([adj, torch.Tensor([0, 0, 0, 1])[None, None].to(adj)], dim=1)
+            camera.camera_to_worlds = torch.bmm(camera.camera_to_worlds, adj)
 
     def get_loss_dict(self, loss_dict: dict) -> None:
         """Add regularization"""
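A standalone sketch of what the new apply_to_camera does with the correction: pad the optimizer's 3x4 adjustment to a 4x4 homogeneous transform and right-multiply it into the camera pose (identity adjustment used here as a stand-in for the learned one):

import torch

camera_to_worlds = torch.eye(4)[:3].unsqueeze(0)  # [1, 3, 4] hypothetical pose
adj = torch.eye(4)[:3].unsqueeze(0)  # [1, 3, 4] stand-in correction

adj = torch.cat([adj, torch.Tensor([0, 0, 0, 1])[None, None].to(adj)], dim=1)  # -> [1, 4, 4]
camera_to_worlds = torch.bmm(camera_to_worlds, adj)  # -> [1, 3, 4]
print(camera_to_worlds.shape)  # torch.Size([1, 3, 4])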
21 changes: 10 additions & 11 deletions nerfstudio/configs/method_configs.py
@@ -597,7 +597,8 @@
     viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
     vis="viewer",
 )
-
+from dataclasses import field
+from collections import defaultdict
 method_configs["gaussian-splatting"] = TrainerConfig(
     method_name="gaussian-splatting",
     steps_per_eval_image=10,
@@ -606,7 +607,7 @@
     steps_per_eval_all_images=1000000,  # set to a very large number so we don't eval with all images
     max_num_iterations=30000,
     mixed_precision=False,
-    gradient_accumulation_steps=1,
+    gradient_accumulation_steps={"camera_opt": 100},
     pipeline=VanillaPipelineConfig(
         datamanager=FullImageDatamanagerConfig(
             dataparser=ColmapDataParserConfig(load_3D_points=True),
@@ -618,7 +619,7 @@
             "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15),
             "scheduler": ExponentialDecaySchedulerConfig(
                 lr_final=1.6e-6,
-                max_steps=15000,
+                max_steps=30000,
             ),
         },
         "color": {
@@ -641,17 +642,15 @@
         },
         "scaling": {
             "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15),
-            "scheduler": ExponentialDecaySchedulerConfig(
-                lr_final=1e-3,
-                max_steps=15000,
-            ),
+            "scheduler": None,
         },
         "rotation": {
             "optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15),
-            "scheduler": ExponentialDecaySchedulerConfig(
-                lr_final=1e-4,
-                max_steps=15000,
-            ),
+            "scheduler": None,
         },
+        "camera_opt": {
+            "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
+            "scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000),
+        },
     },
     viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
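The new gradient_accumulation_steps is a per-parameter-group mapping rather than a single integer; groups that are not listed default to 1 (see the defaultdict added in trainer.py below). A quick sketch of the resulting behavior:

from collections import defaultdict

gradient_accumulation_steps = defaultdict(lambda: 1)
gradient_accumulation_steps.update({"camera_opt": 100})

assert gradient_accumulation_steps["camera_opt"] == 100  # accumulates over 100 steps
assert gradient_accumulation_steps["scaling"] == 1  # unlisted groups step every iteration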
42 changes: 26 additions & 16 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -141,22 +141,26 @@ def cache_images(self, cache_images_option):
                         0,
                     ]
                 )
-                newK,roi = cv2.getOptimalNewCameraMatrix(K,distortion_params,(image.shape[1],image.shape[0]),0)
+                newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
                 image = cv2.undistort(image, K, distortion_params, None, newK)
-                #crop the image and update the intrinsics accordingly
-                x,y,w,h = roi
-                image = image[y:y+h,x:x+w]
+                # crop the image and update the intrinsics accordingly
+                x, y, w, h = roi
+                image = image[y : y + h, x : x + w]
                 K = newK
-                #update the width, height
+                # update the width, height
                 self.train_dataset.cameras.width[i] = w
                 self.train_dataset.cameras.height[i] = h
 
             elif camera.camera_type.item() == CameraType.FISHEYE.value:
                 distortion_params = np.array(
                     [distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
                 )
-                newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(K, distortion_params, (image.shape[1],image.shape[0]), np.eye(3), balance=0)
-                map1, map2 = cv2.fisheye.initUndistortRectifyMap(K, distortion_params, np.eye(3), newK, (image.shape[1],image.shape[0]), cv2.CV_32FC1)
+                newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
+                    K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
+                )
+                map1, map2 = cv2.fisheye.initUndistortRectifyMap(
+                    K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
+                )
                 # and then remap:
                 image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
                 K = newK
@@ -204,21 +208,25 @@ def cache_images(self, cache_images_option):
                         0,
                     ]
                 )
-                newK,roi = cv2.getOptimalNewCameraMatrix(K,distortion_params,(image.shape[1],image.shape[0]),0)
+                newK, roi = cv2.getOptimalNewCameraMatrix(K, distortion_params, (image.shape[1], image.shape[0]), 0)
                 image = cv2.undistort(image, K, distortion_params, None, newK)
-                #crop the image and update the intrinsics accordingly
-                x,y,w,h = roi
-                image = image[y:y+h,x:x+w]
+                # crop the image and update the intrinsics accordingly
+                x, y, w, h = roi
+                image = image[y : y + h, x : x + w]
                 K = newK
-                #update the width, height
+                # update the width, height
                 self.train_dataset.cameras.width[i] = w
                 self.train_dataset.cameras.height[i] = h
             elif camera.camera_type.item() == CameraType.FISHEYE.value:
                 distortion_params = np.array(
                     [distortion_params[0], distortion_params[1], distortion_params[2], distortion_params[3]]
                 )
-                newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(K, distortion_params, (image.shape[1],image.shape[0]), np.eye(3), balance=0)
-                map1, map2 = cv2.fisheye.initUndistortRectifyMap(K, distortion_params, np.eye(3), newK, (image.shape[1],image.shape[0]), cv2.CV_32FC1)
+                newK = cv2.fisheye.estimateNewCameraMatrixForUndistortRectify(
+                    K, distortion_params, (image.shape[1], image.shape[0]), np.eye(3), balance=0
+                )
+                map1, map2 = cv2.fisheye.initUndistortRectifyMap(
+                    K, distortion_params, np.eye(3), newK, (image.shape[1], image.shape[0]), cv2.CV_32FC1
+                )
                 # and then remap:
                 image = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
                 K = newK
@@ -325,10 +333,12 @@ def next_train(self, step: int) -> Tuple[Cameras, Dict]:
 
         data = deepcopy(self.cached_train[image_idx])
         data["image"] = data["image"].to(self.device)
-        data['cam_idx'] = torch.tensor(image_idx).to(self.device)
 
         assert len(self.train_dataset.cameras.shape) == 1, "Assumes single batch dimension"
         camera = self.train_dataset.cameras[image_idx : image_idx + 1].to(self.device)
+        if camera.metadata is None:
+            camera.metadata = {}
+        camera.metadata["cam_idx"] = image_idx
         return camera, data
 
     def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
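For reference, a self-contained sketch of the undistort-and-crop pattern used in cache_images above, with synthetic intrinsics and made-up distortion coefficients; alpha=0 asks OpenCV for a new camera matrix whose image contains no invalid border pixels, and roi is the valid crop:

import cv2
import numpy as np

image = np.zeros((480, 640, 3), dtype=np.uint8)  # synthetic image
K = np.array([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])
dist = np.array([-0.2, 0.05, 0.0, 0.0])  # k1, k2, p1, p2 (made up)

newK, roi = cv2.getOptimalNewCameraMatrix(K, dist, (image.shape[1], image.shape[0]), 0)
undistorted = cv2.undistort(image, K, dist, None, newK)
x, y, w, h = roi
undistorted = undistorted[y : y + h, x : x + w]  # crop, then use newK as the intrinsics
print(undistorted.shape)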
21 changes: 21 additions & 0 deletions nerfstudio/engine/optimizers.py
@@ -135,6 +135,12 @@ def zero_grad_all(self) -> None:
         for _, optimizer in self.optimizers.items():
             optimizer.zero_grad()
 
+    def zero_grad_some(self, param_groups: List[str]) -> None:
+        """Zero the gradients for the given parameter groups."""
+        for param_group in param_groups:
+            optimizer = self.optimizers[param_group]
+            optimizer.zero_grad()
+
     def optimizer_scaler_step_all(self, grad_scaler: GradScaler) -> None:
         """Take an optimizer step using a grad scaler.
@@ -149,6 +155,21 @@ def optimizer_scaler_step_all(self, grad_scaler: GradScaler) -> None:
         if any(any(p.grad is not None for p in g["params"]) for g in optimizer.param_groups):
             grad_scaler.step(optimizer)
 
+    def optimizer_scaler_step_some(self, grad_scaler: GradScaler, param_groups: List[str]) -> None:
+        """Take an optimizer step using a grad scaler ONLY on the specified param groups.
+
+        Args:
+            grad_scaler: GradScaler to use
+            param_groups: the parameter groups to step
+        """
+        for param_group in param_groups:
+            optimizer = self.optimizers[param_group]
+            max_norm = self.config[param_group]["optimizer"].max_norm
+            if max_norm is not None:
+                grad_scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(self.parameters[param_group], max_norm)
+            if any(any(p.grad is not None for p in g["params"]) for g in optimizer.param_groups):
+                grad_scaler.step(optimizer)
+
     def optimizer_step_all(self) -> None:
         """Run step for all optimizers."""
         for param_group, optimizer in self.optimizers.items():
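A toy demonstration of stepping only a subset of optimizers through a GradScaler, mirroring the logic of optimizer_scaler_step_some (hypothetical group names and parameters; scaler disabled so it runs on CPU):

import torch
from torch.cuda.amp import GradScaler

params = {
    "color": torch.nn.Parameter(torch.zeros(3)),
    "camera_opt": torch.nn.Parameter(torch.zeros(6)),
}
optimizers = {name: torch.optim.Adam([p], lr=1e-3) for name, p in params.items()}
grad_scaler = GradScaler(enabled=False)  # same API as the enabled path

loss = params["color"].sum() + params["camera_opt"].sum()
grad_scaler.scale(loss).backward()

for group in ["color"]:  # step this group only; "camera_opt" keeps accumulating
    opt = optimizers[group]
    if any(any(p.grad is not None for p in g["params"]) for g in opt.param_groups):
        grad_scaler.step(opt)
grad_scaler.update()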
41 changes: 25 additions & 16 deletions nerfstudio/engine/trainer.py
@@ -24,8 +24,8 @@
 from dataclasses import dataclass, field
 from pathlib import Path
 from threading import Lock
-from typing import Dict, List, Literal, Optional, Tuple, Type, cast
-
+from typing import DefaultDict, Dict, List, Literal, Optional, Tuple, Type, Union, cast
+from collections import defaultdict
 import torch
 from nerfstudio.configs.experiment_config import ExperimentConfig
 from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager
@@ -81,8 +81,8 @@ class TrainerConfig(ExperimentConfig):
     """Path to checkpoint file."""
     log_gradients: bool = False
     """Optionally log gradients during training"""
-    gradient_accumulation_steps: int = 1
-    """Number of steps to accumulate gradients over."""
+    gradient_accumulation_steps: Dict = field(default_factory=lambda: {})
+    """Number of steps to accumulate gradients over. Contains a mapping of {param_group: num}"""
 
 
 class Trainer:
@@ -119,7 +119,8 @@ def __init__(self, config: TrainerConfig, local_rank: int = 0, world_size: int =
         self.mixed_precision: bool = self.config.mixed_precision
         self.use_grad_scaler: bool = self.mixed_precision or self.config.use_grad_scaler
         self.training_state: Literal["training", "paused", "completed"] = "training"
-        self.gradient_accumulation_steps: int = self.config.gradient_accumulation_steps
+        self.gradient_accumulation_steps: DefaultDict = defaultdict(lambda: 1)
+        self.gradient_accumulation_steps.update(self.config.gradient_accumulation_steps)
 
         if self.device == "cpu":
             self.mixed_precision = False
@@ -462,18 +463,26 @@ def train_iteration(self, step: int) -> TRAIN_INTERATION_OUTPUT:
             step: Current training step.
         """
 
-        self.optimizers.zero_grad_all()
+        needs_zero = [
+            group for group in self.optimizers.parameters.keys() if step % self.gradient_accumulation_steps[group] == 0
+        ]
+        self.optimizers.zero_grad_some(needs_zero)
         cpu_or_cuda_str: str = self.device.split(":")[0]
-        assert (
-            self.gradient_accumulation_steps > 0
-        ), f"gradient_accumulation_steps must be > 0, not {self.gradient_accumulation_steps}"
-        for _ in range(self.gradient_accumulation_steps):
-            with torch.autocast(device_type=cpu_or_cuda_str, enabled=self.mixed_precision):
-                _, loss_dict, metrics_dict = self.pipeline.get_train_loss_dict(step=step)
-                loss = functools.reduce(torch.add, loss_dict.values())
-                loss /= self.gradient_accumulation_steps
-            self.grad_scaler.scale(loss).backward()  # type: ignore
-        self.optimizers.optimizer_scaler_step_all(self.grad_scaler)
 
+        with torch.autocast(device_type=cpu_or_cuda_str, enabled=self.mixed_precision):
+            _, loss_dict, metrics_dict = self.pipeline.get_train_loss_dict(step=step)
+            loss = functools.reduce(torch.add, loss_dict.values())
+        self.grad_scaler.scale(loss).backward()  # type: ignore
+        needs_step = [
+            group
+            for group in self.optimizers.parameters.keys()
+            if step % self.gradient_accumulation_steps[group] == self.gradient_accumulation_steps[group] - 1
+        ]
+        self.optimizers.optimizer_scaler_step_some(self.grad_scaler, needs_step)
 
         if self.config.log_gradients:
             total_grad = 0
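Under the new scheme, a group with accumulation n is zeroed when step % n == 0 and stepped when step % n == n - 1, so its gradients add up across n consecutive iterations in between. A toy trace of that cadence:

accum = {"camera_opt": 3, "scaling": 1}  # small n for illustration
for step in range(6):
    zeroed = [g for g in accum if step % accum[g] == 0]
    stepped = [g for g in accum if step % accum[g] == accum[g] - 1]
    print(step, zeroed, stepped)
# "scaling" is zeroed and stepped every iteration, while "camera_opt" is zeroed
# at steps 0 and 3 and stepped at steps 2 and 5, accumulating gradients in between.

Note that, unlike the old loop, the accumulated loss is no longer divided by the number of accumulation steps, so a group's effective update magnitude grows with its accumulation count.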
16 changes: 16 additions & 0 deletions nerfstudio/model_components/losses.py
@@ -567,6 +567,22 @@ def scale_gradients_by_distance_squared(
         out[key], _ = cast(Tuple[Tensor, Tensor], _GradientScaler.apply(value, scaling))
     return out
 
+
+def scale_gauss_gradients_by_distance_squared(
+    gauss_outs: torch.nn.Parameter,
+    gauss_depths: Float[Tensor, "n_gauss 1"],
+    padding_eps: float = 0.1,
+    far_dist: float = 1,
+) -> torch.nn.Parameter:
+    """
+    Scale gradients by the gaussian distance to the pixel
+    """
+    scaling = (torch.square(gauss_depths / far_dist) + padding_eps).clamp(0, 1)
+    num_dims_to_add = len(gauss_outs.shape) - 1
+    scaling = scaling.view(-1, *([1] * num_dims_to_add))
+    out, _ = cast(Tuple[Tensor, Tensor], _GradientScaler.apply(gauss_outs, scaling))
+    return out
+
 
 def depth_ranking_loss(rendered_depth, gt_depth):
     """
Expand Down
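The scaling factor in the new function is min(1, (d / far_dist)^2 + padding_eps): gradients of near gaussians are damped roughly quadratically with depth, while anything at or beyond far_dist keeps its full gradient. A quick numeric check of that factor:

import torch

padding_eps, far_dist = 0.1, 1.0  # the defaults above
gauss_depths = torch.tensor([[0.1], [0.5], [1.0], [2.0]])
scaling = (torch.square(gauss_depths / far_dist) + padding_eps).clamp(0, 1)
print(scaling.squeeze())  # tensor([0.1100, 0.3500, 1.0000, 1.0000])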

1 comment on commit bb7c7f9

@16Huzeyu


In gaussian_splatting.py, line 505, what do you mean by "# currently relies on the branch vickie/camera-grads"? I can't find this branch in nerfstudio.
