Add nd gaussians support to gaussian splatting method #2568

Open
wants to merge 84 commits into
base: main
Commits
749059d
First pass at dataloader, set up file structure (minus the actual ren…
jake-austin Sep 8, 2023
bb90b94
Basic working, still needs to be profiled and tuned, some preliminary…
jake-austin Sep 12, 2023
8bf2121
Adding more TODOs
jake-austin Sep 12, 2023
1a214ba
Temp save
jake-austin Sep 12, 2023
e5873d8
Going back to original code for speed of testing and correctness for now
jake-austin Sep 12, 2023
9c061f2
Config fixes
brentyi Sep 13, 2023
e3a0f02
fix small bugs and allow testing on images with aspect-ratio not equa…
Zhuoyang-Pan Sep 14, 2023
0e1c8f0
Everything implemented but needs testing
jake-austin Sep 16, 2023
2ea816c
Fixed callbacks causing no splits
jake-austin Sep 16, 2023
274d8f2
Adding another TODO
jake-austin Sep 16, 2023
f6bdfac
Configify
jake-austin Sep 16, 2023
8c2b1b8
Configifying scene extent
jake-austin Sep 16, 2023
b24835d
More config
jake-austin Sep 16, 2023
35be6fd
Accounting for camera distortion, caching undistorted images
jake-austin Sep 18, 2023
75d2d14
add number of gaussians in logger; allow testing in nerfacto
Zhuoyang-Pan Sep 20, 2023
c2fef38
inhouse impl start
kerrj Sep 22, 2023
3c40314
merge main
kerrj Sep 22, 2023
c0d5e19
inhouse impl start
kerrj Sep 26, 2023
7abcedb
split/duplicate thresholds done, not the parameters themselves
kerrj Sep 27, 2023
7f46f32
adding spherical harmonics -- make sure to pull diff rast too
vye16 Sep 28, 2023
92da81e
splitting and dupping
kerrj Sep 28, 2023
2a7db9c
Merge branch 'justin/inhouse' of https://github.com/jake-austin/gauss…
kerrj Sep 28, 2023
ef7f9b7
sh scheduling, resolution scheduling, param tuning
kerrj Sep 28, 2023
d6c49d3
params, add random background
kerrj Sep 29, 2023
84b8241
fix memory leak, reset opacity exp_avg when resetting opacity
kerrj Sep 29, 2023
e57eb79
pc initialization + sh init
AdamRashid96 Sep 30, 2023
c0b8dcf
merge
AdamRashid96 Sep 30, 2023
53d8031
radii pruning
kerrj Sep 30, 2023
01efb8e
grad accum helps, use plenoxel sh code
kerrj Sep 30, 2023
4f462ea
fix order of cull and split, add random init support
kerrj Sep 30, 2023
b350a68
integrate render pipeline (WIP), fix ssim loss
kerrj Sep 30, 2023
3501e9b
working on sh
AdamRashid96 Sep 30, 2023
8dcaa09
Merge branch 'justin/inhouse' of https://github.com/jake-austin/gauss…
AdamRashid96 Sep 30, 2023
ef2c54c
fix sh by splitting learning rates + comment out random init on top o…
AdamRashid96 Oct 1, 2023
50fbb95
adding corrected splitting gradients
vye16 Oct 3, 2023
b843abf
adding ndc pixel factor for screenspace grads
vye16 Oct 3, 2023
9f25457
small change in print output
vye16 Oct 3, 2023
2ee7e0f
nsamps option for duplicating, cropping, and depth rendering
kerrj Oct 3, 2023
4b1d6df
fixing xy diff -> tiling issue
vye16 Oct 6, 2023
732b525
fixing merge conflict
vye16 Oct 6, 2023
b762ed8
testing nd-gaussians
vye16 Oct 7, 2023
d06263e
fix bug in split function, to be used with n-d gaussians branch
kerrj Oct 10, 2023
66d0d7e
merge
kerrj Oct 10, 2023
17e95e1
more bugfix, number of gaussians doesnt go crazy anymore
kerrj Oct 10, 2023
66ddab3
wait for all cams to be seen after opacity reset for culling
kerrj Oct 10, 2023
70ade42
fix jittering in markdown in viewer beta
kerrj Oct 11, 2023
622342c
Revert "fix jittering in markdown in viewer beta"
kerrj Oct 11, 2023
90957bc
cleanup
kerrj Oct 11, 2023
c113f29
param tweaks, dont generate raybundle in viewer
kerrj Oct 12, 2023
ccf1dd5
dataloader memory leak
kerrj Oct 12, 2023
06cde5a
migrate to gsplat repo
kerrj Oct 13, 2023
cfd41a8
merge
kerrj Oct 13, 2023
6e20b7b
add depth composite, really slow streaming though
kerrj Oct 13, 2023
bef5064
params, remove normalization on quats
kerrj Oct 13, 2023
d8b76fe
lint, clean
kerrj Oct 13, 2023
f4cc901
lint
kerrj Oct 13, 2023
1cf455d
clean
kerrj Oct 13, 2023
59e06dd
more lint
kerrj Oct 13, 2023
40c05be
fix undistort intrinsics in datamanager
kerrj Oct 14, 2023
d46aede
Update nerfstudio/viewer/server/render_state_machine.py
kerrj Oct 16, 2023
46fde7c
Update nerfstudio/viewer/server/render_state_machine.py
kerrj Oct 16, 2023
8a3e247
Update nerfstudio/viewer/server/render_state_machine.py
kerrj Oct 16, 2023
bb7c7f9
quality improvements, param changes, camera optimization support, mis…
kerrj Oct 18, 2023
9385d0a
Merge branch 'gaussian-splatting' of https://github.com/nerfstudio-pr…
kerrj Oct 18, 2023
072c3ed
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Oct 18, 2023
7480242
Merge branch 'main' into gaussian-splatting
kerrj Oct 19, 2023
d68138d
Merge branch 'main' of https://github.com/nerfstudio-project/nerfstudio
kerrj Oct 19, 2023
d1d7584
merge
kerrj Oct 20, 2023
b148e2e
static rendering never triggered
kerrj Oct 20, 2023
ded3025
Merge branch 'justin/multiclient-bug' into gaussian-splatting
kerrj Oct 20, 2023
09d5a2a
index on gaussian-splatting: 8a3e2473 Update nerfstudio/viewer/server…
ginazhouhuiwu Oct 20, 2023
2dc7f95
WIP on gaussian-splatting: 8a3e2473 Update nerfstudio/viewer/server/r…
ginazhouhuiwu Oct 20, 2023
d751ae8
convert pix to screen percent for split/cull size
kerrj Oct 21, 2023
8ecdee7
merge
kerrj Oct 21, 2023
8e753b8
Initial eval with updated get_metrics_dict
ginazhouhuiwu Oct 21, 2023
cd119d2
Merge branch 'gaussian-splatting' of https://github.com/nerfstudio-pr…
kerrj Oct 21, 2023
c044d66
Gaussian Splatting export function (#2535)
akristoffersen Oct 23, 2023
332f280
bugfixes, param tweaks
kerrj Oct 23, 2023
5a18444
cleanup
kerrj Oct 23, 2023
c5de49d
Merge branch 'gaussian-splatting' of https://github.com/nerfstudio-pr…
kerrj Oct 23, 2023
e3822de
Gaussian splatting process data fixes (#2554)
machenmusik Oct 24, 2023
36e9034
Ply dataparser (#2557)
CardiacMangoes Oct 25, 2023
e21b5e3
adding nd gaussians to gsplat method
ethanweber Oct 29, 2023
028b1d5
revert trainer code
ethanweber Oct 29, 2023
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -34,7 +34,7 @@
"python.envFile": "${workspaceFolder}/.env",
"python.formatting.provider": "none",
"black-formatter.args": ["--line-length=120"],
-"python.linting.pylintEnabled": false,
+"python.linting.pylintEnabled": true,
"python.linting.flake8Enabled": false,
"python.linting.enabled": true,
"python.testing.unittestEnabled": false,
15 changes: 13 additions & 2 deletions nerfstudio/cameras/camera_optimizers.py
@@ -33,6 +33,7 @@
from nerfstudio.utils import poses as pose_utils
from nerfstudio.engine.optimizers import OptimizerConfig
from nerfstudio.engine.schedulers import SchedulerConfig
from nerfstudio.cameras.cameras import Cameras


@dataclass
@@ -44,10 +45,10 @@ class CameraOptimizerConfig(InstantiateConfig):
mode: Literal["off", "SO3xR3", "SE3"] = "off"
"""Pose optimization strategy to use. If enabled, we recommend SO3xR3."""

-trans_l2_penalty: float = 1e-2
+trans_l2_penalty: float = 1e-4
"""L2 penalty on translation parameters."""

-rot_l2_penalty: float = 1e-3
+rot_l2_penalty: float = 1e-4
"""L2 penalty on rotation parameters."""

optimizer: Optional[OptimizerConfig] = field(default=None)
@@ -146,6 +147,16 @@ def apply_to_raybundle(self, raybundle: RayBundle) -> None:
raybundle.origins = raybundle.origins + correction_matrices[:, :3, 3]
raybundle.directions = torch.bmm(correction_matrices[:, :3, :3], raybundle.directions[..., None]).squeeze()

    def apply_to_camera(self, camera: Cameras) -> None:
        """Apply the pose correction to the camera"""
        if self.config.mode != "off":
            assert camera.metadata is not None, "Must provide id of camera in its metadata"
            assert "cam_idx" in camera.metadata, "Must provide id of camera in its metadata"
            camera_idx = camera.metadata["cam_idx"]
            adj = self([camera_idx])  # type: ignore
            # Append the homogeneous row [0, 0, 0, 1] so the 3x4 correction becomes 4x4.
            adj = torch.cat([adj, torch.Tensor([0, 0, 0, 1])[None, None].to(adj)], dim=1)
            camera.camera_to_worlds = torch.bmm(camera.camera_to_worlds, adj)

    def get_loss_dict(self, loss_dict: dict) -> None:
        """Add regularization"""
        if self.config.mode != "off":
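The new `apply_to_camera` method pads the 3x4 pose correction with a `[0, 0, 0, 1]` row and right-multiplies it onto the camera-to-world matrix. Below is a minimal pure-Python sketch of that composition, using nested lists in place of tensors. The helper names `to_homogeneous` and `compose` are hypothetical and not part of nerfstudio.

```python
from typing import List

Matrix = List[List[float]]


def to_homogeneous(mat_3x4: Matrix) -> Matrix:
    """Append the row [0, 0, 0, 1] so a 3x4 rigid transform becomes 4x4."""
    return [list(row) for row in mat_3x4] + [[0.0, 0.0, 0.0, 1.0]]


def compose(c2w_3x4: Matrix, adj_3x4: Matrix) -> Matrix:
    """Return c2w @ homogeneous(adj), which is again a 3x4 matrix."""
    adj_4x4 = to_homogeneous(adj_3x4)
    return [
        [sum(row[k] * adj_4x4[k][j] for k in range(4)) for j in range(4)]
        for row in c2w_3x4
    ]


# Camera rotated 90 degrees about z, positioned at (1, 2, 3).
c2w = [
    [0.0, -1.0, 0.0, 1.0],
    [1.0, 0.0, 0.0, 2.0],
    [0.0, 0.0, 1.0, 3.0],
]
# Correction: no rotation, shift of 0.5 along the camera's local x axis.
adj = [
    [1.0, 0.0, 0.0, 0.5],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
]
corrected = compose(c2w, adj)
```

Because the correction is multiplied on the right, its translation is expressed in the camera's local frame: in this example the shift along local x moves the camera position along world y.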
3 changes: 2 additions & 1 deletion nerfstudio/configs/base_config.py
@@ -88,6 +88,7 @@ class LocalWriterConfig(InstantiateConfig):
writer.EventName.VIS_RAYS_PER_SEC,
writer.EventName.TEST_RAYS_PER_SEC,
writer.EventName.ETA,
writer.EventName.GAUSSIAN_NUM,
)
"""specifies which stats will be logged/printed to terminal"""
max_log_size: int = 10
@@ -144,7 +145,7 @@ class ViewerConfig(PrintableConfig):
"""Whether to kill the training job when it has completed. Note this will stop rendering in the viewer."""
image_format: Literal["jpeg", "png"] = "jpeg"
"""Image format viewer should use; jpeg is lossy compression, while png is lossless."""
-jpeg_quality: int = 90
+jpeg_quality: int = 70
"""Quality tradeoff to use for jpeg compression."""
make_share_url: bool = False
"""Viewer beta feature: print a shareable URL. `vis` must be set to viewer_beta; this flag is otherwise ignored."""
68 changes: 67 additions & 1 deletion nerfstudio/configs/method_configs.py
@@ -22,7 +22,6 @@
from typing import Dict

import tyro
from nerfstudio.data.pixel_samplers import PairPixelSamplerConfig

from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
from nerfstudio.configs.base_config import ViewerConfig
@@ -49,6 +48,7 @@
from nerfstudio.data.datasets.depth_dataset import DepthDataset
from nerfstudio.data.datasets.sdf_dataset import SDFDataset
from nerfstudio.data.datasets.semantic_dataset import SemanticDataset
from nerfstudio.data.pixel_samplers import PairPixelSamplerConfig
from nerfstudio.engine.optimizers import AdamOptimizerConfig, RAdamOptimizerConfig
from nerfstudio.engine.schedulers import (
CosineDecaySchedulerConfig,
@@ -59,8 +59,10 @@
from nerfstudio.field_components.temporal_distortions import TemporalDistortionKind
from nerfstudio.fields.sdf_field import SDFFieldConfig
from nerfstudio.models.depth_nerfacto import DepthNerfactoModelConfig
from nerfstudio.models.gaussian_splatting import GaussianSplattingModelConfig
from nerfstudio.models.generfacto import GenerfactoModelConfig
from nerfstudio.models.instant_ngp import InstantNGPModelConfig
from nerfstudio.data.dataparsers.colmap_dataparser import ColmapDataParserConfig
from nerfstudio.models.mipnerf import MipNerfModel
from nerfstudio.models.nerfacto import NerfactoModelConfig
from nerfstudio.models.neus import NeuSModelConfig
@@ -69,6 +71,7 @@
from nerfstudio.models.tensorf import TensoRFModelConfig
from nerfstudio.models.vanilla_nerf import NeRFModel, VanillaModelConfig
from nerfstudio.pipelines.base_pipeline import VanillaPipelineConfig
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig
from nerfstudio.pipelines.dynamic_batch import DynamicBatchPipelineConfig
from nerfstudio.plugins.registry import discover_methods

@@ -87,6 +90,7 @@
"generfacto": "Generative Text to NeRF model",
"neus": "Implementation of NeuS. (slow)",
"neus-facto": "Implementation of NeuS-Facto. (slow)",
"gaussian-splatting": "Gaussian Splatting model",
}

method_configs["nerfacto"] = TrainerConfig(
@@ -594,6 +598,68 @@
vis="viewer",
)

method_configs["gaussian-splatting"] = TrainerConfig(
    method_name="gaussian-splatting",
    steps_per_eval_image=100,
    steps_per_eval_batch=100,
    steps_per_save=2000,
    steps_per_eval_all_images=100000,
    max_num_iterations=30000,
    mixed_precision=False,
    gradient_accumulation_steps={"camera_opt": 100, "color": 10, "shs": 10},
    pipeline=VanillaPipelineConfig(
        datamanager=FullImageDatamanagerConfig(
            dataparser=ColmapDataParserConfig(load_3D_points=True),
        ),
        model=GaussianSplattingModelConfig(),
    ),
    optimizers={
        "xyz": {
            "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(
                lr_final=1.6e-6,
                max_steps=30000,
            ),
        },
        "color": {
            "optimizer": AdamOptimizerConfig(lr=2.5e-3, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(
                lr_final=1e-3,
                max_steps=30000,
            ),
        },
        "shs": {
            "optimizer": AdamOptimizerConfig(lr=2.5e-3 / 20, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(
                lr_final=1e-3 / 20,
                max_steps=30000,
            ),
        },
        "opacity": {
            "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15),
            "scheduler": None,
        },
        "nd_values": {
            "optimizer": AdamOptimizerConfig(lr=2.5e-3, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(
                lr_final=1e-3,
                max_steps=30000,
            ),
        },
        "scaling": {
            "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(lr_final=1e-3, max_steps=30000),
        },
        "rotation": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None},
        "camera_opt": {
            "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
            "scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000),
        },
    },
    viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
    vis="viewer_beta",
)
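
Most parameter groups in this config pair Adam with `ExponentialDecaySchedulerConfig(lr_final=..., max_steps=30000)`. The sketch below shows roughly what such a schedule does. It assumes log-linear interpolation between the initial and final rate with no warmup, and the actual scheduler may differ in details. The function name `exponential_decay_lr` is hypothetical.

```python
import math


def exponential_decay_lr(step: int, lr_init: float, lr_final: float, max_steps: int) -> float:
    """Interpolate between lr_init and lr_final in log space (assumed behavior)."""
    # Clamp progress to [0, 1] so the rate stays at lr_final after max_steps.
    t = min(max(step / max_steps, 0.0), 1.0)
    return math.exp(math.log(lr_init) * (1.0 - t) + math.log(lr_final) * t)


# Values taken from the "xyz" group in the config above.
xyz_lr_start = exponential_decay_lr(0, 1.6e-4, 1.6e-6, 30000)
xyz_lr_mid = exponential_decay_lr(15000, 1.6e-4, 1.6e-6, 30000)
xyz_lr_end = exponential_decay_lr(30000, 1.6e-4, 1.6e-6, 30000)
```

Under this assumption the midpoint rate is the geometric mean of the two endpoints, so the "xyz" group would sit near 1.6e-5 halfway through training.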


def merge_methods(methods, method_descriptions, new_methods, new_descriptions, overwrite=True):
"""Merge new methods and descriptions into existing methods and descriptions.
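The config above sets `gradient_accumulation_steps={"camera_opt": 100, "color": 10, "shs": 10}`, so those groups accumulate gradients over several iterations before their optimizer steps. A minimal sketch of that idea follows. The exact step condition used by the nerfstudio trainer is an assumption here, and `groups_to_step` is a hypothetical helper.

```python
from typing import Dict, List


def groups_to_step(step: int, accumulation_steps: Dict[str, int], group_names: List[str]) -> List[str]:
    """Return the parameter groups whose optimizer would step at this iteration."""
    due = []
    for name in group_names:
        # Groups missing from the dict step on every iteration.
        every = accumulation_steps.get(name, 1)
        if (step + 1) % every == 0:
            due.append(name)
    return due


accumulation = {"camera_opt": 100, "color": 10, "shs": 10}
group_names = ["xyz", "color", "shs", "camera_opt"]

first_iteration = groups_to_step(0, accumulation, group_names)
tenth_iteration = groups_to_step(9, accumulation, group_names)
hundredth_iteration = groups_to_step(99, accumulation, group_names)
```

With these settings the position parameters update every iteration, the color and spherical-harmonics groups every tenth, and the camera optimizer every hundredth.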
27 changes: 13 additions & 14 deletions nerfstudio/data/datamanagers/base_datamanager.py
@@ -21,12 +21,13 @@
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from functools import cached_property
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
ForwardRef,
Generic,
List,
Literal,
@@ -35,9 +36,8 @@
Type,
Union,
cast,
ForwardRef,
get_origin,
get_args,
get_origin,
)

import torch
@@ -47,17 +47,17 @@
from typing_extensions import TypeVar

from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig
-from nerfstudio.cameras.cameras import CameraType
+from nerfstudio.cameras.cameras import Cameras, CameraType
from nerfstudio.cameras.rays import RayBundle
from nerfstudio.configs.base_config import InstantiateConfig
from nerfstudio.configs.dataparser_configs import AnnotatedDataParserUnion
from nerfstudio.data.dataparsers.base_dataparser import DataparserOutputs
from nerfstudio.data.dataparsers.blender_dataparser import BlenderDataParserConfig
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.pixel_samplers import (
PatchPixelSamplerConfig,
PixelSampler,
PixelSamplerConfig,
PatchPixelSamplerConfig,
)
from nerfstudio.data.utils.dataloaders import (
CacheDataloader,
@@ -67,9 +67,8 @@
from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes
from nerfstudio.model_components.ray_generators import RayGenerator
-from nerfstudio.utils.misc import IterableWrapper
+from nerfstudio.utils.misc import IterableWrapper, get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE
-from nerfstudio.utils.misc import get_orig_class


def variable_res_collate(batch: List[Dict]) -> Dict:
@@ -131,7 +130,7 @@ class DataManager(nn.Module):
To get data, use the next_train and next_eval functions.
This data manager's next_train and next_eval methods will return 2 things:

-1. A Raybundle: This will contain the rays we are sampling, with latents and
+1. Rays or a camera: This will contain the rays/camera we are sampling, with latents and
conditionals attached (everything needed at inference)
2. A "batch" of auxiliary information: This will contain the mask, the ground truth
pixels, etc needed to actually train, score, etc the model
@@ -246,7 +245,7 @@ def setup_eval(self):
"""Sets up the data manager for evaluation"""

@abstractmethod
-def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
+def next_train(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
"""Returns the next batch of data from the train data manager.

Args:
@@ -258,25 +257,25 @@ def next_train(self, step: int) -> Tuple[RayBundle, Dict]:
raise NotImplementedError

@abstractmethod
-def next_eval(self, step: int) -> Tuple[RayBundle, Dict]:
+def next_eval(self, step: int) -> Tuple[Union[RayBundle, Cameras], Dict]:
"""Returns the next batch of data from the eval data manager.

Args:
step: the step number of the eval image to retrieve
Returns:
-A tuple of the ray bundle for the image, and a dictionary of additional batch information
+A tuple of the ray/camera for the image, and a dictionary of additional batch information
such as the groundtruth image.
"""
raise NotImplementedError

@abstractmethod
-def next_eval_image(self, step: int) -> Tuple[int, RayBundle, Dict]:
+def next_eval_image(self, step: int) -> Tuple[int, Union[RayBundle, Cameras], Dict]:
"""Retrieve the next eval image.

Args:
step: the step number of the eval image to retrieve
Returns:
-A tuple of the step number, the ray bundle for the image, and a dictionary of
+A tuple of the step number, the ray/camera for the image, and a dictionary of
additional batch information such as the groundtruth image.
"""
raise NotImplementedError
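
With these signature changes, `next_train`, `next_eval`, and `next_eval_image` may hand back either a ray bundle or a camera, so callers need to branch on the type they receive. Here is a minimal sketch of that dispatch, using hypothetical stand-in classes (`RayBundleStub`, `CamerasStub`) rather than the real nerfstudio types:

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union


@dataclass
class RayBundleStub:
    """Stand-in for a ray bundle; only tracks how many rays it holds."""

    num_rays: int


@dataclass
class CamerasStub:
    """Stand-in for a single full-image camera."""

    width: int
    height: int


def describe_batch(data: Tuple[Union[RayBundleStub, CamerasStub], Dict]) -> str:
    """Branch on the input type, as a consumer of next_train might."""
    inputs, batch = data
    if isinstance(inputs, CamerasStub):
        return f"camera {inputs.width}x{inputs.height}, batch keys {sorted(batch)}"
    return f"{inputs.num_rays} rays, batch keys {sorted(batch)}"


camera_summary = describe_batch((CamerasStub(640, 480), {"image": None}))
ray_summary = describe_batch((RayBundleStub(4096), {"image": None}))
```

A ray-based model would take the second branch and a splatting model the first, while the auxiliary batch dictionary keeps the same role in both cases.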
@@ -313,7 +312,7 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:

@dataclass
class VanillaDataManagerConfig(DataManagerConfig):
-"""A basic data manager"""
+"""A basic data manager for a ray-based model"""

_target: Type = field(default_factory=lambda: VanillaDataManager)
"""Target class to instantiate."""