-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Implements the `ImplicitRenderer` and `VolumeRenderer`. Reviewed By: gkioxari Differential Revision: D24418791 fbshipit-source-id: 127f21186d8e210895db1dcd0681f09f230d81a4
- Loading branch information
1 parent
e6a32bf
commit b466c38
Showing
8 changed files
with
1,575 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,372 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
from typing import Callable, Tuple | ||
|
||
import torch | ||
|
||
from ...ops.utils import eyes | ||
from ...structures import Volumes | ||
from ...transforms import Transform3d | ||
from ..cameras import CamerasBase | ||
from .raysampling import RayBundle | ||
from .utils import _validate_ray_bundle_variables, ray_bundle_variables_to_ray_points | ||
|
||
|
||
# The implicit renderer class should be initialized with a | ||
# function for raysampling and a function for raymarching. | ||
|
||
# During the forward pass: | ||
# 1) The raysampler: | ||
# - samples rays from input cameras | ||
# - transforms the rays to world coordinates | ||
# 2) The volumetric_function (which is a callable argument of the forward pass) | ||
# evaluates ray_densities and ray_features at the sampled ray-points. | ||
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching | ||
# algorithm to render each ray. | ||
|
||
|
||
class ImplicitRenderer(torch.nn.Module):
    """
    A class for rendering a batch of implicit surfaces. The class should
    be initialized with a raysampler and raymarcher class which both have
    to be a `Callable`.

    VOLUMETRIC_FUNCTION

    The `forward` function of the renderer accepts as input the rendering cameras
    as well as the `volumetric_function` `Callable`, which defines a field of opacity
    and feature vectors over the 3D domain of the scene.

    A standard `volumetric_function` has the following signature:
    ```
    def volumetric_function(ray_bundle: RayBundle) -> Tuple[torch.Tensor, torch.Tensor]
    ```
    With the following arguments:
        `ray_bundle`: A RayBundle object containing the following variables:
            `rays_origins`: A tensor of shape `(minibatch, ..., 3)` denoting
                the origins of the rendering rays.
            `rays_directions`: A tensor of shape `(minibatch, ..., 3)`
                containing the direction vectors of rendering rays.
            `rays_lengths`: A tensor of shape
                `(minibatch, ..., num_points_per_ray)` containing the
                lengths at which the ray points are sampled.
    Calling `volumetric_function` then returns the following:
        `rays_densities`: A tensor of shape
            `(minibatch, ..., num_points_per_ray, opacity_dim)` containing
            an opacity vector for each ray point.
        `rays_features`: A tensor of shape
            `(minibatch, ..., num_points_per_ray, feature_dim)` containing
            a feature vector for each ray point.

    Example:
        A simple volumetric function of a 0-centered
        RGB sphere with a unit diameter is defined as follows:
        ```
        def volumetric_function(
            ray_bundle: RayBundle,
        ) -> Tuple[torch.Tensor, torch.Tensor]:

            # first convert the ray origins, directions and lengths
            # to 3D ray point locations in world coords
            rays_points_world = ray_bundle_to_ray_points(ray_bundle)

            # set the densities as an inverse sigmoid of the
            # ray point distance from the sphere centroid
            rays_densities = torch.sigmoid(
                -100.0 * rays_points_world.norm(dim=-1, keepdim=True)
            )

            # set the ray features to RGB colors proportional
            # to the 3D location of the projection of ray points
            # on the sphere surface
            rays_features = torch.nn.functional.normalize(
                rays_points_world, dim=-1
            ) * 0.5 + 0.5

            return rays_densities, rays_features
        ```
    """

    def __init__(self, raysampler: Callable, raymarcher: Callable):
        """
        Args:
            raysampler: A `Callable` that takes as input scene cameras
                (an instance of `CamerasBase`) and returns a `RayBundle` that
                describes the rays emitted from the cameras.
            raymarcher: A `Callable` that receives the response of the
                `volumetric_function` (an input to `self.forward`) evaluated
                along the sampled rays, and renders the rays with a
                ray-marching algorithm.
        """
        super().__init__()

        # validate eagerly so misconfiguration fails at construction,
        # not on the first forward pass
        if not callable(raysampler):
            raise ValueError('"raysampler" has to be a "Callable" object.')
        if not callable(raymarcher):
            raise ValueError('"raymarcher" has to be a "Callable" object.')

        self.raysampler = raysampler
        self.raymarcher = raymarcher

    def forward(
        self, cameras: CamerasBase, volumetric_function: Callable, **kwargs
    ) -> Tuple[torch.Tensor, RayBundle]:
        """
        Render a batch of images using a volumetric function
        represented as a callable (e.g. a Pytorch module).

        Args:
            cameras: A batch of cameras that render the scene. A `self.raysampler`
                takes the cameras as input and samples rays that pass through the
                domain of the volumetric function.
            volumetric_function: A `Callable` that accepts the parametrizations
                of the rendering rays and returns the densities and features
                at the respective 3D points of the rendering rays. Please refer to
                the main class documentation for details.

        Returns:
            images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
                containing the result of the rendering.
            ray_bundle: A `RayBundle` containing the parametrizations of the
                sampled rendering rays.
        """
        if not callable(volumetric_function):
            raise ValueError('"volumetric_function" has to be a "Callable" object.')

        # first call the ray sampler that returns the RayBundle parametrizing
        # the rendering rays.
        ray_bundle = self.raysampler(
            cameras=cameras, volumetric_function=volumetric_function, **kwargs
        )
        # ray_bundle.origins - minibatch x ... x 3
        # ray_bundle.directions - minibatch x ... x 3
        # ray_bundle.lengths - minibatch x ... x n_pts_per_ray
        # ray_bundle.xys - minibatch x ... x 2

        # given sampled rays, call the volumetric function that
        # evaluates the densities and features at the locations of the
        # ray points
        rays_densities, rays_features = volumetric_function(
            ray_bundle=ray_bundle, cameras=cameras, **kwargs
        )
        # ray_densities - minibatch x ... x n_pts_per_ray x density_dim
        # ray_features - minibatch x ... x n_pts_per_ray x feature_dim

        # finally, march along the sampled rays to obtain the renders
        images = self.raymarcher(
            rays_densities=rays_densities,
            rays_features=rays_features,
            ray_bundle=ray_bundle,
            **kwargs,
        )
        # images - minibatch x ... x (feature_dim + opacity_dim)

        return images, ray_bundle
|
||
|
||
# The volume renderer class should be initialized with a | ||
# function for raysampling and a function for raymarching. | ||
|
||
# During the forward pass: | ||
# 1) The raysampler: | ||
# - samples rays from input cameras | ||
# - transforms the rays to world coordinates | ||
# 2) The scene volumes (which are an argument of the forward function) | ||
# are then sampled at the locations of the ray-points to generate | ||
# ray_densities and ray_features. | ||
# 3) The raymarcher takes ray_densities and ray_features and uses a raymarching | ||
# algorithm to render each ray. | ||
|
||
|
||
class VolumeRenderer(torch.nn.Module):
    """
    A class for rendering a batch of Volumes. The class should
    be initialized with a raysampler and a raymarcher class which both have
    to be a `Callable`.
    """

    def __init__(
        self, raysampler: Callable, raymarcher: Callable, sample_mode: str = "bilinear"
    ):
        """
        Args:
            raysampler: A `Callable` that takes as input scene cameras
                (an instance of `CamerasBase`) and returns a `RayBundle` that
                describes the rays emitted from the cameras.
            raymarcher: A `Callable` that receives the `volumes`
                (an instance of `Volumes` input to `self.forward`)
                sampled at the ray-points, and renders the rays with a
                ray-marching algorithm.
            sample_mode: Defines the algorithm used to sample the volumetric
                voxel grid. Can be either "bilinear" or "nearest".
        """
        super().__init__()

        # delegate the actual rendering to an ImplicitRenderer whose
        # volumetric function is a VolumeSampler built per-forward-call
        self.renderer = ImplicitRenderer(raysampler, raymarcher)
        self._sample_mode = sample_mode

    def forward(
        self, cameras: CamerasBase, volumes: Volumes, **kwargs
    ) -> Tuple[torch.Tensor, RayBundle]:
        """
        Render a batch of images using raymarching over rays cast through
        input `Volumes`.

        Args:
            cameras: A batch of cameras that render the scene. A `self.raysampler`
                takes the cameras as input and samples rays that pass through the
                domain of the volumetric function.
            volumes: An instance of the `Volumes` class representing a
                batch of volumes that are being rendered.

        Returns:
            images: A tensor of shape `(minibatch, ..., feature_dim + opacity_dim)`
                containing the result of the rendering.
            ray_bundle: A `RayBundle` containing the parametrizations of the
                sampled rendering rays.
        """
        volumetric_function = VolumeSampler(volumes, sample_mode=self._sample_mode)
        return self.renderer(
            cameras=cameras, volumetric_function=volumetric_function, **kwargs
        )
|
||
|
||
class VolumeSampler(torch.nn.Module):
    """
    A class that allows to sample a batch of volumes `Volumes`
    at 3D points sampled along projection rays.
    """

    def __init__(self, volumes: Volumes, sample_mode: str = "bilinear"):
        """
        Args:
            volumes: An instance of the `Volumes` class representing a
                batch of volumes that are being rendered.
            sample_mode: Defines the algorithm used to sample the volumetric
                voxel grid. Can be either "bilinear" or "nearest".
        """
        super().__init__()
        if not isinstance(volumes, Volumes):
            raise ValueError("'volumes' have to be an instance of the 'Volumes' class.")
        self._volumes = volumes
        self._sample_mode = sample_mode

    def _get_ray_directions_transform(self):
        """
        Compose the ray-directions transform by removing the translation component
        from the volume global-to-local coords transform.
        """
        world2local = self._volumes.get_world_to_local_coords_transform().get_matrix()
        directions_transform_matrix = eyes(
            4,
            N=world2local.shape[0],
            device=world2local.device,
            dtype=world2local.dtype,
        )
        # keep only the rotation/scale part; directions must not be translated
        directions_transform_matrix[:, :3, :3] = world2local[:, :3, :3]
        directions_transform = Transform3d(matrix=directions_transform_matrix)
        return directions_transform

    def forward(
        self, ray_bundle: RayBundle, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Given an input ray parametrization, the forward function samples
        `self._volumes` at the respective 3D ray-points.

        Args:
            ray_bundle: A RayBundle object with the following fields:
                rays_origins_world: A tensor of shape `(minibatch, ..., 3)` denoting the
                    origins of the sampling rays in world coords.
                rays_directions_world: A tensor of shape `(minibatch, ..., 3)`
                    containing the direction vectors of sampling rays in world coords.
                rays_lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
                    containing the lengths at which the rays are sampled.

        Returns:
            rays_densities: A tensor of shape
                `(minibatch, ..., num_points_per_ray, opacity_dim)` containing the
                density vectors sampled from the volume at the locations of
                the ray points.
            rays_features: A tensor of shape
                `(minibatch, ..., num_points_per_ray, feature_dim)` containing the
                feature vectors sampled from the volume at the locations of
                the ray points.
        """

        # take out the interesting parts of ray_bundle
        rays_origins_world = ray_bundle.origins
        rays_directions_world = ray_bundle.directions
        rays_lengths = ray_bundle.lengths

        # validate the inputs
        _validate_ray_bundle_variables(
            rays_origins_world, rays_directions_world, rays_lengths
        )
        if self._volumes.densities().shape[0] != rays_origins_world.shape[0]:
            raise ValueError("Input volumes have to have the same batch size as rays.")

        #########################################################
        # 1) convert the origins/directions to the local coords #
        #########################################################

        # origins are mapped with the world_to_local transform of the volumes
        rays_origins_local = self._volumes.world_to_local_coords(rays_origins_world)

        # obtain the Transform3d object that transforms ray directions to local coords
        directions_transform = self._get_ray_directions_transform()

        # transform the directions to the local coords
        rays_directions_local = directions_transform.transform_points(
            rays_directions_world.view(rays_lengths.shape[0], -1, 3)
        ).view(rays_directions_world.shape)

        ############################
        # 2) obtain the ray points #
        ############################

        # this op produces a fairly big tensor (minibatch, ..., n_samples_per_ray, 3)
        rays_points_local = ray_bundle_variables_to_ray_points(
            rays_origins_local, rays_directions_local, rays_lengths
        )

        ########################
        # 3) sample the volume #
        ########################

        # generate the tensor for sampling
        volumes_densities = self._volumes.densities()
        dim_density = volumes_densities.shape[1]
        volumes_features = self._volumes.features()
        # adjust the volumes_features variable in case we have a feature-less volume
        if volumes_features is None:
            dim_feature = 0
            data_to_sample = volumes_densities
        else:
            dim_feature = volumes_features.shape[1]
            # concatenate along the channel dim so a single grid_sample call
            # samples densities and features together
            data_to_sample = torch.cat((volumes_densities, volumes_features), dim=1)

        # reshape to a size which grid_sample likes
        rays_points_local_flat = rays_points_local.view(
            rays_points_local.shape[0], -1, 1, 1, 3
        )

        # run the grid sampler
        data_sampled = torch.nn.functional.grid_sample(
            data_to_sample,
            rays_points_local_flat,
            align_corners=True,
            mode=self._sample_mode,
        )

        # permute the dimensions & reshape after sampling
        data_sampled = data_sampled.permute(0, 2, 3, 4, 1).view(
            *rays_points_local.shape[:-1], data_sampled.shape[1]
        )

        # split back to densities and features
        rays_densities, rays_features = data_sampled.split(
            [dim_density, dim_feature], dim=-1
        )

        return rays_densities, rays_features
Oops, something went wrong.