
Loading multiple pipelines and trainers #161

Open
ttsesm opened this issue Jul 5, 2023 · 5 comments


ttsesm commented Jul 5, 2023

@ttsesm

You can add multiple objects to the scene graph like this:

from wisp.framework import WispState
from wisp.models import Pipeline
from wisp.renderer.core.api import add_to_scene_graph

wisp_state = WispState()  # Global shared info lives here; created once per app
# nef1/tracer1 and nef2/tracer2 are placeholders: see main_nerf.py for an example of creating them and a Pipeline
nerf_pipeline1 = Pipeline(nef=nef1, tracer=tracer1)
nerf_pipeline2 = Pipeline(nef=nef2, tracer=tracer2)

# Optional NeRF args are the args that NeuralRadianceFieldPackedRenderer.__init__ takes as input:
# https://github.com/NVIDIAGameWorks/kaolin-wisp/blob/main/wisp/renderer/core/renderers/radiance_pipeline_renderer.py#L26
# batch_size is an optional setup arg which hints to the visualizer how many rays can be processed at once,
# i.e. the pipeline's batch_size used at inference time
nerf_specific_args = dict(batch_size=2**14)

# Add objects to the scene graph: if interactive mode is on, this makes sure the visualizer can display them.
add_to_scene_graph(state=wisp_state, name="My NeRF", obj=nerf_pipeline1, **nerf_specific_args)
add_to_scene_graph(state=wisp_state, name="Another NeRF", obj=nerf_pipeline2, **nerf_specific_args)

Each object has its own ObjectTransform so you can control their orientation, dimensions and location around the scene.
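A per-object transform can be thought of as a 4x4 homogeneous matrix that maps the object's local coordinates into the shared scene. The plain-Python sketch below assumes a uniform scale followed by a translation; make_transform and apply_transform are hypothetical helpers for illustration, not wisp's ObjectTransform API:

```python
from typing import List, Sequence

Matrix = List[List[float]]

def make_transform(scale: float, translation: Sequence[float]) -> Matrix:
    """Build a 4x4 homogeneous matrix that scales uniformly, then translates."""
    tx, ty, tz = translation
    return [
        [scale, 0.0, 0.0, tx],
        [0.0, scale, 0.0, ty],
        [0.0, 0.0, scale, tz],
        [0.0, 0.0, 0.0, 1.0],
    ]

def apply_transform(m: Matrix, point: Sequence[float]) -> List[float]:
    """Map a 3D point from object space to world space."""
    p = [point[0], point[1], point[2], 1.0]
    out = [sum(m[r][c] * p[c] for c in range(4)) for r in range(4)]
    return [out[0] / out[3], out[1] / out[3], out[2] / out[3]]

# Two objects with identical local geometry, placed side by side in the scene.
left = make_transform(0.5, (-1.0, 0.0, 0.0))
right = make_transform(0.5, (1.0, 0.0, 0.0))
print(apply_transform(left, (1.0, 1.0, 1.0)))   # [-0.5, 0.5, 0.5]
print(apply_transform(right, (1.0, 1.0, 1.0)))  # [1.5, 0.5, 0.5]
```

Because each object carries its own matrix, the same local geometry can appear at different places and sizes without touching the underlying neural field.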

Some more explanation about the scene graph and NeuralRadianceFieldPackedRenderer is available in the docs:
https://kaolin-wisp.readthedocs.io/en/latest/pages/renderer.html#the-scenegraph

Keep in mind that the app runs in an infinite loop and that add_to_scene_graph is an asynchronous request: the object is actually added to the scene graph when the next frame is drawn.
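The deferred behaviour can be illustrated with a toy model. ToySceneGraph, request_add and draw_frame are hypothetical names, not wisp internals; the sketch only shows why an object registered through an asynchronous request is not visible until the next frame is processed:

```python
from typing import Any, Dict, List, Tuple

class ToySceneGraph:
    """Hypothetical stand-in for a scene graph with deferred updates."""

    def __init__(self) -> None:
        self.objects: Dict[str, Any] = {}          # what the renderer sees
        self._pending: List[Tuple[str, Any]] = []  # requests not yet applied

    def request_add(self, name: str, obj: Any) -> None:
        # Only records the request; nothing becomes visible yet.
        self._pending.append((name, obj))

    def draw_frame(self) -> List[str]:
        # Pending requests are applied when the next frame is drawn.
        for name, obj in self._pending:
            self.objects[name] = obj
        self._pending.clear()
        return sorted(self.objects)

graph = ToySceneGraph()
graph.request_add("My NeRF", object())
graph.request_add("Another NeRF", object())
print(sorted(graph.objects))  # []
print(graph.draw_frame())     # ['Another NeRF', 'My NeRF']
```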

If you have further questions please open a separate issue, it makes it easier for me to track :)

Originally posted by @orperel in #35 (comment)


ttsesm commented Jul 5, 2023

@orperel, following up on the response above, I have tried the following.

First, a brief description of what I am trying to achieve, so that we are on the same page: my goal is to load and train multiple SDF pipelines.

I started from the default app/nglod demo, which I am trying to modify for this task. At the moment I use the default nglod_octree.yaml as a common configuration file for both pipelines, and I load the same .obj file for both SDF pipelines. My script currently looks as follows:

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.


import os
import logging
import torch
from typing import Optional
from wisp.app_utils import default_log_setup
from wisp.config import parse_config, configure, autoconfig, instantiate, print_config, get_config_target
from wisp.framework import WispState
from wisp.datasets import SDFDataset, MeshSampledSDFDataset, OctreeSampledSDFDataset
from wisp.accelstructs import OctreeAS, AxisAlignedBBoxAS
from wisp.models.grids import OctreeGrid, TriplanarGrid, HashGrid
from wisp.models.nefs import NeuralSDF
from wisp.models.pipeline import Pipeline
from wisp.tracers import PackedSDFTracer
from wisp.trainers import SDFTrainer, ConfigSDFTrainer
from wisp.trainers.tracker import Tracker, ConfigTracker
from wisp.renderer.core.api import add_to_scene_graph



@configure
class SDFAppConfig:
    """ A script for training neural SDF variants with grid backbones.
    See: Takikawa et al. 2021 - "Neural Geometric Level of Detail: Real-time Rendering with Implicit 3D Shapes".
    """
    blas: autoconfig(OctreeAS.from_mesh, AxisAlignedBBoxAS)
    """ Bottom Level Acceleration structure used by the neural field grid to track occupancy, accelerate queries. """
    grid: autoconfig(OctreeGrid, HashGrid.from_geometric, HashGrid.from_octree, TriplanarGrid)
    """ Feature grid used by the neural field. Grids are located in `wisp.models.grids` """
    nef: autoconfig(NeuralSDF)
    """ Signed distance field configuration, including the feature grid, a decoder and optional embedder.
    NeuralSDF maps 3D coordinates -> SDF values. Uses spatial feature grids internally for faster feature interpolation.
    """
    tracer: autoconfig(PackedSDFTracer)
    """ Tracers are responsible for taking input rays, marching them through the neural field to render 
    an output RenderBuffer. In this app, the tracer is only used for rendering during test time.
    """
    dataset: autoconfig(MeshSampledSDFDataset, OctreeSampledSDFDataset)
    """ SDF dataset used by the trainer. """
    trainer: ConfigSDFTrainer
    """ Configuration for trainer used to optimize the neural sdf. """
    tracker: ConfigTracker
    """ Experiments tracker for reporting to tensorboard & wandb, creating visualizations and aggregating metrics. """
    log_level: int = logging.INFO
    """ Sets the global output log level: logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR, logging.CRITICAL """
    pretrained: Optional[str] = None
    """ If specified, a pretrained model will be loaded from this path. None will create a new model. """
    device: str = 'cuda'
    """ Device used to run the optimization """
    interactive: bool = os.environ.get('WISP_HEADLESS') != '1'
    """ Set to --interactive=True for interactive mode which uses the GUI.
    The default value is set according to the env variable WISP_HEADLESS, if available. 
    Otherwise, interactive mode is on by default. """


cfg = parse_config(SDFAppConfig, yaml_arg='--config')  # Obtain args by priority: cli args > config yaml > config defaults
device = torch.device(cfg.device)
default_log_setup(cfg.log_level)
if cfg.interactive:
    cfg.tracer.bg_color = 'black'
    cfg.trainer.render_every = -1
    # cfg.trainer.save_every = -1
    cfg.trainer.valid_every = -1
print_config(cfg)


# Create model from scratch
# blas is the occupancy acceleration structure, possibly initialized sparsely from a mesh
blas = instantiate(cfg.blas)
grid = instantiate(cfg.grid, blas=blas)  # A grid keeps track of both features and occupancy
nef = instantiate(cfg.nef, grid=grid)    # nef here is a SDF which uses a grid as the backbone
# tracer here is used to efficiently render the SDF at test time (note: not used during optimization).
# Wisp's implementation of Neural Geometric LOD uses PackedSDFTracer to trace the neural field:
# - Packed: each ray yields a custom number of sphere tracing steps,
#   which are therefore packed in a flat form within a tensor,
#   see: https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.batch.html#packed
# - SDF: Signed Distance Function
tracer = instantiate(cfg.tracer)
pipeline = Pipeline(nef=nef, tracer=tracer)  # Binds neural field and tracer together to a single SDF callable
pipeline1 = Pipeline(nef=nef, tracer=tracer)  # Note: reuses the same nef and tracer as `pipeline`, so both pipelines share one set of weights


# Loads a SDF dataset comprising of sdf samples generated from an existing mesh.
# MeshSampledSDFDataset - refers to samples generated directly from the mesh surface. This dataset is decoupled from
#   the optimized model and it's blas (occupancy structure).
# OctreeSampledSDFDataset - refers to samples generated from an octree, initialized from a mesh.
#   This dataset has the benefit of limiting the sampling region to areas which are actually occupied by the mesh.
#   It also allows for equal distribution of samples per the octree cells.
if get_config_target(cfg.dataset) is OctreeSampledSDFDataset:
    assert OctreeSampledSDFDataset.supports_blas(pipeline.nef.grid.blas)
train_dataset: SDFDataset = instantiate(cfg.dataset, occupancy_struct=blas)

# Joint trainer / app state - scene_state contains various global definitions
exp_name: str = cfg.trainer.exp_name
scene_state: WispState = WispState()
scene_graph = scene_state.graph
tracker = Tracker(cfg=cfg.tracker, exp_name=exp_name)
trainer = SDFTrainer(cfg=cfg.trainer,
                     pipeline=pipeline,
                     train_dataset=train_dataset,
                     tracker=tracker,
                     device=device,
                     scene_state=scene_state)

trainer1 = SDFTrainer(cfg=cfg.trainer,
                     pipeline=pipeline1,
                     train_dataset=train_dataset,
                     tracker=tracker,
                     device=device,
                     scene_state=scene_state)

add_to_scene_graph(state=scene_state, name="obj1", obj=pipeline)
add_to_scene_graph(state=scene_state, name="obj2", obj=pipeline1)

def joint_train_step():
    # Advance both trainers by one optimization step per app frame
    trainer.iterate()
    trainer1.iterate()

# The trainer is responsible for managing the optimization life-cycles and can be operated in 2 modes:
# - Headless, which will run the train() function until all training steps are exhausted.
# - Interactive mode, which uses the gui. In this case, an OptimizationApp uses events to prompt the trainer to
#   take training steps, while also taking care to render output to users (see: iterate()).
#   In interactive mode, trainers can also share information with the app through the scene_state (WispState object).

from wisp.renderer.app.optimization_app import OptimizationApp
scene_state.renderer.device = trainer.device  # Use same device for trainer and app renderer
app = OptimizationApp(wisp_state=scene_state, trainer_step_func=joint_train_step, experiment_name=exp_name)  # Pass the step function itself, not the result of calling it
app.run()  # Run in interactive mode

The above gives me the following error:

2023-07-05 16:00:12,768|    INFO| Active LODs: [2, 3, 4, 5, 6, 7]
2023-07-05 16:00:12,805|    INFO| Built dual octree and trinkets
2023-07-05 16:00:12,805|    INFO| # Feature Vectors: 58820
2023-07-05 16:00:12,813|    INFO| Pyramid:tensor([    1,     8,    32,   112,   380,  1340,  5362, 20337,     0],
       dtype=torch.int32)
2023-07-05 16:00:12,813|    INFO| Pyramid Dual: tensor([    8,    27,    82,   262,   904,  3012, 10766, 43788,     0],
       dtype=torch.int32)
2023-07-05 16:00:12,833|    INFO| Computing SDFs for entire samples pool (may take a while)..
2023-07-05 16:00:38,780|    INFO| Total Samples in Pool: 3253920
2023-07-05 16:00:38,780|    INFO| Resampling 500000 samples..
2023-07-05 16:00:38,811|    INFO| Using NVIDIA GeForce RTX 2080 SUPER with CUDA v11.1
2023-07-05 16:00:38,811|    INFO| Total number of parameters: 943809
2023-07-05 16:00:38,814|    INFO| Using NVIDIA GeForce RTX 2080 SUPER with CUDA v11.1
2023-07-05 16:00:38,814|    INFO| Total number of parameters: 943809
2023-07-05 16:00:38,818|    INFO| No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'
[i] Using PYGLFW_IMGUI (GL 3.3)
2023-07-05 16:00:39,541|    INFO| [i] Using PYGLFW_IMGUI (GL 3.3)
[i] Running at 60 frames/second
2023-07-05 16:00:39,560|    INFO| [i] Running at 60 frames/second
/home/ttsesm/Development/spc/venv_spc/lib/python3.9/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/home/ttsesm/Development/spc/venv_spc/lib/python3.9/site-packages/torch/nn/functional.py:3631: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
  warnings.warn(
Traceback (most recent call last):
  File "/Development/spc/app/nglod/main_nglod_test.py", line 137, in <module>
    app.run()  # Run in interactive mode
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 260, in run
    app.run()   # App clock should always run as frequently as possible (background tasks should not be limited)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/glumpy/app/__init__.py", line 362, in run
    run(duration, framecount)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/glumpy/app/__init__.py", line 344, in run
    count = __backend__.process(dt)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/glumpy/app/window/backends/backend_glfw_imgui.py", line 448, in process
    window.dispatch_event('on_draw', dt)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/glumpy/app/window/event.py", line 396, in dispatch_event
    if getattr(self, event_type)(*args):
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 550, in on_draw
    self.render()     # Render objects uploaded to GPU
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 36, in _enable_amp
    return func(self, *args, **kwargs)
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 518, in render
    img, depth_img = self.render_canvas(self.render_core, dt, self.canvas_dirty)
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 407, in render_canvas
    renderbuffer = render_core.render(time_delta, force_render)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/Development/spc/wisp/renderer/core/render_core.py", line 30, in _enable_amp
    return func(self, *args, **kwargs)
  File "/Development/spc/wisp/renderer/core/render_core.py", line 222, in render
    rb = self._render_payload(payload, force_render)
  File "/Development/spc/wisp/renderer/core/render_core.py", line 352, in _render_payload
    out_rb = out_rb.blend(rb, channel_kit=self.state.graph.channels)
  File "/Development/spc/wisp/core/render_buffer.py", line 250, in blend
    out = blend(c1, c2, alpha1, alpha2)
  File "/Development/spc/wisp/core/channel_fn.py", line 219, in blend_alpha_slerp
    c2_weight = (torch.sin((1.0 - t) * omega) / sin_omega).unsqueeze(1)
RuntimeError: The size of tensor a (1600) must match the size of tensor b (1200) at non-singleton dimension 1
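The traceback ends inside an elementwise blend of two render buffers, and any elementwise blend needs both buffers to have matching shapes, so one thing worth checking is whether every object's buffer is produced at the same resolution. The sketch below illustrates the idea with the standard "over" compositing operator on nested lists; it is not wisp's blend_alpha_slerp, and blend_over and shape are hypothetical helpers that check shapes up front to give a clearer error:

```python
from typing import List, Tuple

Image = List[List[float]]  # rows x columns of per-pixel values

def shape(img: Image) -> Tuple[int, int]:
    """Return (rows, columns), assuming all rows have the same length."""
    return (len(img), len(img[0]) if img else 0)

def blend_over(c1: Image, a1: Image, c2: Image, a2: Image) -> Image:
    """Composite layer 1 over layer 2 pixel by pixel; returns premultiplied color."""
    shapes = [shape(x) for x in (c1, a1, c2, a2)]
    if len(set(shapes)) != 1:
        # Elementwise blending is undefined when the buffers differ in size.
        raise ValueError(f"cannot blend buffers of different shapes: {shapes}")
    rows, cols = shapes[0]
    return [
        [c1[r][k] * a1[r][k] + c2[r][k] * a2[r][k] * (1.0 - a1[r][k]) for k in range(cols)]
        for r in range(rows)
    ]

print(blend_over([[1.0]], [[0.5]], [[0.0]], [[1.0]]))  # [[0.5]]
try:
    blend_over([[1.0, 1.0]], [[1.0, 1.0]], [[1.0]], [[1.0]])
except ValueError as err:
    print(err)  # cannot blend buffers of different shapes: [(1, 2), (1, 2), (1, 1), (1, 1)]
```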

So apparently I am doing something wrong. I am not sure whether add_to_scene_graph() is necessary at all, given that I am creating two trainers. If I remove the add_to_scene_graph() calls, my script runs, but only one object is displayed:
[screenshot]
However, I suspect this happens because both objects are registered under the same name, so one overwrites the other.
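If the scene graph keys objects by name (an assumption here: a dict-like mapping from names to pipelines), registering two pipelines under one name leaves only the last one. A plain-Python sketch of that failure mode, plus a hypothetical register_unique helper that derives distinct names:

```python
from typing import Any, Dict

def register_unique(registry: Dict[str, Any], name: str, obj: Any) -> str:
    """Insert obj under name, appending a numeric suffix if name is already taken."""
    unique, i = name, 1
    while unique in registry:
        unique = f"{name}_{i}"
        i += 1
    registry[unique] = obj
    return unique

plain: Dict[str, Any] = {}
plain["sdf"] = "pipeline A"
plain["sdf"] = "pipeline B"  # same key: pipeline A is silently replaced
print(len(plain))            # 1

safe: Dict[str, Any] = {}
print(register_unique(safe, "sdf", "pipeline A"))  # sdf
print(register_unique(safe, "sdf", "pipeline B"))  # sdf_1
print(len(safe))                                   # 2
```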


ttsesm commented Jul 7, 2023

@orperel, I am trying an approach similar to the spc_browser example. I have created a widget_sdf_selector, which I use to create a pipeline for each of my .obj files, similar to what is done there for every .npz file. I pass the config options through the yaml file. However, it fails with the following error:

2023-07-07 17:10:28,858|    INFO| [i] Running at 60 frames/second
Traceback (most recent call last):
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 92, in _call_target
    return _target_(*args, **kwargs)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/hydra_zen/funcs.py", line 100, in zen_processing
    return obj(*args, **kwargs)
  File "Development/spc/wisp/accelstructs/octree_as.py", line 92, in from_mesh
    vertices, faces = mesh_ops.normalize(vertices, faces, 'sphere')
  File "/Development/spc/wisp/ops/mesh/normalize.py", line 30, in normalize
    V_max, _ = torch.max(V, dim=0)
IndexError: max(): Expected reduction dim 0 to have non-zero size.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 260, in run
    app.run()   # App clock should always run as frequently as possible (background tasks should not be limited)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/glumpy/app/__init__.py", line 362, in run
    run(duration, framecount)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/glumpy/app/__init__.py", line 344, in run
    count = __backend__.process(dt)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/glumpy/app/window/backends/backend_glfw_imgui.py", line 448, in process
    window.dispatch_event('on_draw', dt)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/glumpy/app/window/event.py", line 396, in dispatch_event
    if getattr(self, event_type)(*args):
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 550, in on_draw
    self.render()     # Render objects uploaded to GPU
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 36, in _enable_amp
    return func(self, *args, **kwargs)
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 489, in render
    self.render_gui(self.wisp_state)
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 388, in render_gui
    widget.paint(state)
  File "/Development/spc/app/sdf_demo/widget_sdf_selector.py", line 97, in paint
    sdf_pipeline = self.create_pipeline(available_files[selected_file_idx], device=device)
  File "/Development/spc/app/sdf_demo/widget_sdf_selector.py", line 52, in create_pipeline
    blas = instantiate(self.cfg.blas)
  File "/Development/spc/wisp/config/utils.py", line 365, in instantiate
    instance = instantiate(config, **overriden_args)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/hydra_zen/_hydra_overloads.py", line 215, in instantiate
    return hydra_instantiate(config, *args, **kwargs)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 226, in instantiate
    return instantiate_node(
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 347, in instantiate_node
    return _call_target(_target_, partial, args, kwargs, full_key)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 97, in _call_target
    raise InstantiationException(msg) from e
hydra.errors.InstantiationException: Error in call to target 'hydra_zen.funcs.zen_processing':
IndexError('max(): Expected reduction dim 0 to have non-zero size.')

My sdf_selector is the following:

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import os
import numpy as np
import torch
import imgui
from wisp.framework.state import WispState
from wisp.models import Pipeline
from wisp.renderer.gui import WidgetImgui
from wisp.renderer.core.api import add_to_scene_graph, remove_from_scene_graph
from sdf_app_config import SDFAppConfig

from wisp.config import instantiate



class WidgetSDFSelector(WidgetImgui):
    """ A custom widget which lets users browse SPC files, and select them to populate the scene graph. """

    def __init__(self, app_config: SDFAppConfig):
        self.curr_file_idx = 0
        self.inited = False
        self.cfg = app_config

    def create_pipeline(self, filename, device):

        if self.cfg.pretrained and self.cfg.trainer.model_format == "full":
            pipeline = torch.load(self.cfg.pretrained)  # Load a full pretrained pipeline: model + weights
        else:  # Create model from scratch
            # blas is the occupancy acceleration structure, possibly initialized sparsely from a mesh
            blas = instantiate(self.cfg.blas)
            grid = instantiate(self.cfg.grid, blas=blas)  # A grid keeps track of both features and occupancy
            nef = instantiate(self.cfg.nef, grid=grid)  # nef here is a SDF which uses a grid as the backbone
            # tracer here is used to efficiently render the SDF at test time (note: not used during optimization).
            # Wisp's implementation of Neural Geometric LOD uses PackedSDFTracer to trace the neural field:
            # - Packed: each ray yields a custom number of sphere tracing steps,
            #   which are therefore packed in a flat form within a tensor,
            #   see: https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.batch.html#packed
            # - SDF: Signed Distance Function
            tracer = instantiate(self.cfg.tracer)
            pipeline = Pipeline(nef=nef, tracer=tracer)  # Binds neural field and tracer together to a single SDF callable
            if self.cfg.pretrained and self.cfg.trainer.model_format == "state_dict":
                pipeline.load_state_dict(torch.load(self.cfg.pretrained))

        return pipeline

    def paint(self, state: WispState, *args, **kwargs):
        """ Paint will be automatically called by the gui.
            Each widget included by the BrowseSPCApp will be automatically painted.
        """
        expanded, _ = imgui.collapsing_header("Object Browser", visible=True, flags=imgui.TREE_NODE_DEFAULT_OPEN)
        if expanded:
            available_files = state.extent['available_files']
            if len(available_files) == 0:
                if not self.inited:
                    print('Warning: No OBJ files were found in input folder!')
                    self.inited = True
                return
            file_names = []

            for fpath in available_files:
                fnameext = os.path.basename(fpath)
                fname = os.path.splitext(fnameext)[0]
                file_names.append(fname)

            is_clicked, selected_file_idx = imgui.combo("Filename", self.curr_file_idx,
                                                        file_names)  # display your choices here
            if is_clicked or not self.inited:
                old_file_name = file_names[self.curr_file_idx]
                new_file_name = file_names[selected_file_idx]

                # Add new object to the scene graph
                # This will toggle a "marked as dirty" flag,
                # which forces the render core to refresh the scene graph and load the new object
                device = state.renderer.device
                sdf_pipeline = self.create_pipeline(available_files[selected_file_idx], device=device)
                add_to_scene_graph(state, name=new_file_name, obj=sdf_pipeline)

                # Remove old object from scene graph, if it exists
                if old_file_name != new_file_name and old_file_name in state.graph.neural_pipelines:
                    remove_from_scene_graph(state, name=old_file_name)

                self.curr_file_idx = selected_file_idx
            self.inited = True

I have attached the full demo source code in case it helps. Put it inside the app folder, next to the nerf and nglod folders, and run main_sdf_demo.py.

sdf_demo.zip


ttsesm commented Jul 7, 2023

OK, specifying the .obj file from which to build the octree seems to do the job:

    def create_pipeline(self, filename, device):

        if self.cfg.pretrained and self.cfg.trainer.model_format == "full":
            pipeline = torch.load(self.cfg.pretrained)  # Load a full pretrained pipeline: model + weights
        else:  # Create model from scratch
            # blas is the occupancy acceleration structure, possibly initialized sparsely from a mesh
            self.cfg.blas.mesh_path = filename  # FIX: specify the .obj file the octree is built from
            blas = instantiate(self.cfg.blas)
            grid = instantiate(self.cfg.grid, blas=blas)  # A grid keeps track of both features and occupancy
            nef = instantiate(self.cfg.nef, grid=grid)  # nef here is a SDF which uses a grid as the backbone
            # tracer here is used to efficiently render the SDF at test time (note: not used during optimization).
            # Wisp's implementation of Neural Geometric LOD uses PackedSDFTracer to trace the neural field:
            # - Packed: each ray yields a custom number of sphere tracing steps,
            #   which are therefore packed in a flat form within a tensor,
            #   see: https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.batch.html#packed
            # - SDF: Signed Distance Function
            tracer = instantiate(self.cfg.tracer)
            pipeline = Pipeline(nef=nef, tracer=tracer)  # Binds neural field and tracer together to a single SDF callable
            if self.cfg.pretrained and self.cfg.trainer.model_format == "state_dict":
                pipeline.load_state_dict(torch.load(self.cfg.pretrained))

        return pipeline
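The earlier IndexError ("max(): Expected reduction dim 0 to have non-zero size") is raised when reducing over an empty vertex tensor, which is most likely what happened while no mesh path was set. The plain-Python sketch below normalizes vertices into the unit sphere with an explicit guard for that case; normalize_to_unit_sphere is hypothetical and only shows one common way to do a "sphere" normalization, not wisp's mesh_ops.normalize:

```python
import math
from typing import List, Sequence, Tuple

def normalize_to_unit_sphere(vertices: Sequence[Sequence[float]]) -> List[Tuple[float, ...]]:
    """Center vertices on their bounding-box center and scale them into the unit sphere."""
    if len(vertices) == 0:
        # Fail early with a readable message instead of an empty-reduction error.
        raise ValueError("mesh has no vertices: check that the mesh path is set and points to a valid .obj file")
    lo = [min(v[i] for v in vertices) for i in range(3)]
    hi = [max(v[i] for v in vertices) for i in range(3)]
    center = [(lo[i] + hi[i]) / 2.0 for i in range(3)]
    centered = [tuple(v[i] - center[i] for i in range(3)) for v in vertices]
    radius = max(math.sqrt(sum(c * c for c in v)) for v in centered)
    scale = 1.0 / radius if radius > 0 else 1.0  # a single point stays at the origin
    return [tuple(c * scale for c in v) for v in centered]

print(normalize_to_unit_sphere([(0.0, 0.0, 0.0), (2.0, 0.0, 0.0)]))  # [(-1.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
```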

But now I am getting the following error:

2023-07-07 18:32:54,273|    INFO| Active LODs: [2, 3, 4, 5, 6, 7]
2023-07-07 18:32:54,342|    INFO| Built dual octree and trinkets
2023-07-07 18:32:54,342|    INFO| # Feature Vectors: 133843
2023-07-07 18:32:54,354|    INFO| Pyramid:tensor([    1,     8,    34,   146,   684,  3180, 12730, 36573,     0],
       dtype=torch.int32)
2023-07-07 18:32:54,354|    INFO| Pyramid Dual: tensor([     8,     27,     84,    269,   1104,   5451,  25265, 101664,      0],
       dtype=torch.int32)
/Development/spc/venv_spc/lib/python3.9/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Traceback (most recent call last):
  File "/Development/spc/app/sdf_demo/main_sdf_demo.py", line 72, in <module>
    app.run()
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 260, in run
    app.run()   # App clock should always run as frequently as possible (background tasks should not be limited)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/glumpy/app/__init__.py", line 362, in run
    run(duration, framecount)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/glumpy/app/__init__.py", line 344, in run
    count = __backend__.process(dt)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/glumpy/app/window/backends/backend_glfw_imgui.py", line 448, in process
    window.dispatch_event('on_draw', dt)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/glumpy/app/window/event.py", line 396, in dispatch_event
    if getattr(self, event_type)(*args):
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 550, in on_draw
    self.render()     # Render objects uploaded to GPU
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 36, in _enable_amp
    return func(self, *args, **kwargs)
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 518, in render
    img, depth_img = self.render_canvas(self.render_core, dt, self.canvas_dirty)
  File "/Development/spc/wisp/renderer/app/wisp_app.py", line 407, in render_canvas
    renderbuffer = render_core.render(time_delta, force_render)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/Development/spc/wisp/renderer/core/render_core.py", line 30, in _enable_amp
    return func(self, *args, **kwargs)
  File "/Development/spc/wisp/renderer/core/render_core.py", line 222, in render
    rb = self._render_payload(payload, force_render)
  File "/Development/spc/wisp/renderer/core/render_core.py", line 321, in _render_payload
    rb = renderer.render(in_rays)
  File "/Development/spc/wisp/renderer/core/api/raytraced_renderer.py", line 123, in render
    rb = self.tracer(self.nef, rays=rays, channels=self.channels)
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Development/spc/wisp/tracers/base_tracer.py", line 158, in forward
    rb = self.trace(nef, rays, requested_channels, requested_extra_channels, **input_args)
  File "/Development/spc/wisp/tracers/packed_sdf_tracer.py", line 86, in trace
    raytrace_results = nef.grid.raytrace(rays, nef.grid.active_lods[lod_idx], with_exit=True)
  File "/Development/spc/wisp/models/grids/blas_grid.py", line 45, in raytrace
    return self.blas.raytrace(*args, **kwargs)
  File "/Development/spc/wisp/accelstructs/octree_as.py", line 181, in raytrace
    ridx, pidx, depth = spc_render.unbatched_raytrace(
  File "/Development/spc/venv_spc/lib/python3.9/site-packages/kaolin/render/spc/raytrace.py", line 67, in unbatched_raytrace
    output = _C.render.spc.raytrace_cuda(
RuntimeError: Tensor for argument #5 'ray_o' is on CPU, but expected it to be on GPU (while checking arguments for raytrace_cuda)

So I am not sure whether paint() is the right place to add objects to the scene graph 🤔


ttsesm commented Jul 10, 2023

OK, small win: the error above is resolved by moving the pipeline to the CUDA device, since it was created on the CPU by default.

    def create_pipeline(self, filename, device):

        if self.cfg.pretrained and self.cfg.trainer.model_format == "full":
            pipeline = torch.load(self.cfg.pretrained)  # Load a full pretrained pipeline: model + weights
        else:  # Create model from scratch
            # blas is the occupancy acceleration structure, possibly initialized sparsely from a mesh
            self.cfg.blas.mesh_path = filename  # Specify the .obj file the octree is built from
            blas = instantiate(self.cfg.blas)
            grid = instantiate(self.cfg.grid, blas=blas)  # A grid keeps track of both features and occupancy
            nef = instantiate(self.cfg.nef, grid=grid)  # nef here is a SDF which uses a grid as the backbone
            # tracer here is used to efficiently render the SDF at test time (note: not used during optimization).
            # Wisp's implementation of Neural Geometric LOD uses PackedSDFTracer to trace the neural field:
            # - Packed: each ray yields a custom number of sphere tracing steps,
            #   which are therefore packed in a flat form within a tensor,
            #   see: https://kaolin.readthedocs.io/en/latest/modules/kaolin.ops.batch.html#packed
            # - SDF: Signed Distance Function
            tracer = instantiate(self.cfg.tracer)
            pipeline = Pipeline(nef=nef, tracer=tracer)  # Binds neural field and tracer together to a single SDF callable
            if self.cfg.pretrained and self.cfg.trainer.model_format == "state_dict":
                pipeline.load_state_dict(torch.load(self.cfg.pretrained))

        return pipeline.to(device)

[screenshot]

@orperel, I am now not sure how to connect WidgetOptimization() with the tracker and the trainer for each of the objects. Should I add them to the paint() function of widget_sdf_selector or to my demo_app file? 🤔


ttsesm commented Jul 10, 2023

The sdf_browser above works if you want to create one pipeline at a time for each object. I have also tried to load multiple pipelines, so that I can render and compute the SDF for multiple objects at once, but I again ended up with the tensor size error:

c2_weight = (torch.sin((1.0 - t) * omega) / sin_omega).unsqueeze(1)
RuntimeError: The size of tensor a (1218) must match the size of tensor b (1230) at non-singleton dimension 1

My guess is that this is due to the different numbers of vertices/faces in my actual mesh files. So now I am stuck on two main issues:

  1. How to add multiple pipelines built from meshes of different sizes at the same time
  2. How and where to create the tracker/trainer for each pipeline (or a single one for all of them)
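For the second question, one possible arrangement (an assumption, not a pattern confirmed by the wisp maintainers) is to keep one trainer per object in a mapping keyed by the object's scene-graph name, and hand the app a single step function that advances all of them. ToyTrainer and ToyApp below are hypothetical stand-ins so the sketch runs without wisp; note that trainer_step_func receives the function object itself, not the result of calling it:

```python
from typing import Callable, Dict

class ToyTrainer:
    """Hypothetical stand-in for a per-object trainer; counts its iterations."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.steps = 0

    def iterate(self) -> None:
        self.steps += 1  # a real trainer would run one optimization step here

class ToyApp:
    """Hypothetical stand-in for an app that calls a step function once per frame."""

    def __init__(self, trainer_step_func: Callable[[], None]) -> None:
        self.trainer_step_func = trainer_step_func

    def run(self, frames: int) -> None:
        for _ in range(frames):
            self.trainer_step_func()

# One trainer per object, keyed by the same name used in the scene graph.
trainers: Dict[str, ToyTrainer] = {name: ToyTrainer(name) for name in ("obj1", "obj2")}

def joint_train_step() -> None:
    for trainer in trainers.values():
        trainer.iterate()

app = ToyApp(trainer_step_func=joint_train_step)  # pass the function, do not call it
app.run(frames=3)
print({name: t.steps for name, t in trainers.items()})  # {'obj1': 3, 'obj2': 3}
```

Keying trainers by the scene-graph name keeps each object's optimization state next to the object it renders, so adding or removing an object only touches one dictionary entry.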

I have also attached the updated widget_sdf_selector.py script in case someone wants to test it.
widget_sdf_selector.zip
