From 9fc661f8b325142323b1925109ccf87ebb5904f2 Mon Sep 17 00:00:00 2001
From: David Novotny
Date: Wed, 6 Jan 2021 08:24:57 -0800
Subject: [PATCH] Tutorial - Fit neural radiance field

Summary: Implements a simple nerf tutorial.

Reviewed By: nikhilaravi

Differential Revision: D24650983

fbshipit-source-id: b3db51c0ed74779ec9b510350d1675b0ae89422c
---
 .../fit_simple_neural_radiance_field.ipynb   | 861 ++++++++++++++++++
 docs/tutorials/utils/generate_cow_renders.py |   7 +-
 2 files changed, 866 insertions(+), 2 deletions(-)
 create mode 100644 docs/tutorials/fit_simple_neural_radiance_field.ipynb

diff --git a/docs/tutorials/fit_simple_neural_radiance_field.ipynb b/docs/tutorials/fit_simple_neural_radiance_field.ipynb
new file mode 100644
index 000000000..6546920dd
--- /dev/null
+++ b/docs/tutorials/fit_simple_neural_radiance_field.ipynb
@@ -0,0 +1,861 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Fit a simple Neural Radiance Field via raymarching\n",
+    "\n",
+    "This tutorial shows how to fit a Neural Radiance Field given a set of views of a scene using differentiable implicit function rendering.\n",
+    "\n",
+    "More specifically, this tutorial will explain how to:\n",
+    "1. Create a differentiable implicit function renderer with either image-grid or Monte Carlo ray sampling.\n",
+    "2. Create an implicit model of a scene.\n",
+    "3. Fit the implicit function (Neural Radiance Field) based on input images using the differentiable implicit renderer. \n",
+    "4. Visualize the learnt implicit function.\n",
+    "\n",
+    "Note that the presented implicit model is a simplified version of NeRF:<br>
\n", + "_Ben Mildenhall, Pratul P. Srinivasan, Matthew Tancik, Jonathan T. Barron, Ravi Ramamoorthi, Ren Ng: NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis, ECCV 2020._\n", + "\n", + "The simplifications include:\n", + "* *Ray sampling*: This notebook does not perform stratified ray sampling but rather ray sampling at equidistant depths.\n", + "* *Rendering*: We do a single rendering pass, as opposed to the original implementation that does a coarse and fine rendering pass.\n", + "* *Architecture*: Our network is shallower which allows for faster optimization possibly at the cost of surface details.\n", + "* *Mask loss*: Since our observations include segmentation masks, we also optimize a silhouette loss that forces rays to either get fully absorbed inside the volume, or to completely pass through it.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 0. Install and Import modules\n", + "If `torch` and `pytorch3d` are not installed, run the following cell:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install torch\n", + "# import sys\n", + "# import torch\n", + "# if torch.__version__=='1.6.0+cu101' and sys.platform.startswith('linux'):\n", + "# !pip install pytorch3d\n", + "# else:\n", + "# !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# %matplotlib inline\n", + "# %matplotlib notebook\n", + "import os\n", + "import sys\n", + "import time\n", + "import json\n", + "import glob\n", + "import torch\n", + "import math\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from PIL import Image\n", + "from IPython import display\n", + "from tqdm.notebook import tqdm\n", + "\n", + "# Data structures and functions for rendering\n", + "from pytorch3d.structures import Volumes\n", + "from pytorch3d.transforms import so3_exponential_map\n", + "from pytorch3d.renderer import (\n", + " FoVPerspectiveCameras, \n", + " NDCGridRaysampler,\n", + " MonteCarloRaysampler,\n", + " EmissionAbsorptionRaymarcher,\n", + " ImplicitRenderer,\n", + " RayBundle,\n", + " ray_bundle_to_ray_points,\n", + ")\n", + "\n", + "# add path for demo utils functions \n", + "sys.path.append(os.path.abspath(''))\n", + "from utils.plot_image_grid import image_grid\n", + "from utils.generate_cow_renders import generate_cow_renders\n", + "\n", + "# obtain the utilized device\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda:0\")\n", + " torch.cuda.set_device(device)\n", + "else:\n", + " print(\n", + " 'Please note that NeRF is a resource-demanding method.'\n", + " + ' Running this notebook on CPU will be extremely slow.'\n", + " + ' We recommend running the example on a GPU'\n", + " + ' with at least 10 GB of memory.'\n", + " )\n", + " device = torch.device(\"cpu\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Generate images of the scene and masks\n", + "\n", + "The following cell generates our training data.\n", + "It renders the cow mesh from the `fit_textured_mesh.ipynb` tutorial from several viewpoints and returns:\n", + "1. A batch of image and silhouette tensors that are produced by the cow mesh renderer.\n", + "2. 
A set of cameras corresponding to each render.\n",
+    "\n",
+    "Note: For the purpose of this tutorial, which aims at explaining the details of implicit rendering, we do not explain how the mesh rendering, implemented in the `generate_cow_renders` function, works. Please refer to `fit_textured_mesh.ipynb` for a detailed explanation of mesh rendering."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "target_cameras, target_images, target_silhouettes = generate_cow_renders(num_views=40, azimuth_range=180)\n",
+    "print(f'Generated {len(target_images)} images/silhouettes/cameras.')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. Initialize the implicit renderer\n",
+    "\n",
+    "The following initializes an implicit renderer that emits a ray from each pixel of a target image and samples a set of uniformly-spaced points along the ray. At each ray point, the corresponding density and color values are obtained by querying the corresponding location in the neural model of the scene (the model is described & instantiated in a later cell).\n",
+    "\n",
+    "The renderer is composed of a *raymarcher* and a *raysampler*.\n",
+    "- The *raysampler* is responsible for emitting rays from image pixels and sampling the points along them. Here, we use two different raysamplers:\n",
+    "    - `MonteCarloRaysampler` is used to generate rays from a random subset of pixels of the image plane. The random subsampling of pixels is carried out during **training** to decrease the memory consumption of the implicit model.\n",
+    "    - `NDCGridRaysampler`, which follows the standard PyTorch3D coordinate grid convention (+X from right to left; +Y from bottom to top; +Z away from the user). In combination with the implicit model of the scene, `NDCGridRaysampler` consumes a large amount of memory and, hence, is only used for visualizing the results of the training at **test** time.\n",
+    "- The *raymarcher* takes the densities and colors sampled along each ray and renders each ray into a color and an opacity value of the ray's source pixel. Here we use the `EmissionAbsorptionRaymarcher`, which implements the standard Emission-Absorption raymarching algorithm."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# render_size describes the size of both sides of the \n",
+    "# rendered images in pixels. 
Since an advantage of \n", + "# Neural Radiance Fields are high quality renders\n", + "# with a significant amount of details, we render\n", + "# the implicit function at double the size of \n", + "# target images.\n", + "render_size = target_images.shape[1] * 2\n", + "\n", + "# Our rendered scene is centered around (0,0,0) \n", + "# and is enclosed inside a bounding box\n", + "# whose side is roughly equal to 3.0 (world units).\n", + "volume_extent_world = 3.0\n", + "\n", + "# 1) Instantiate the raysamplers.\n", + "\n", + "# Here, NDCGridRaysampler generates a rectangular image\n", + "# grid of rays whose coordinates follow the PyTorch3d\n", + "# coordinate conventions.\n", + "raysampler_grid = NDCGridRaysampler(\n", + " image_height=render_size,\n", + " image_width=render_size,\n", + " n_pts_per_ray=128,\n", + " min_depth=0.1,\n", + " max_depth=volume_extent_world,\n", + ")\n", + "\n", + "# MonteCarloRaysampler generates a random subset \n", + "# of `n_rays_per_image` rays emitted from the image plane.\n", + "raysampler_mc = MonteCarloRaysampler(\n", + " min_x = -1.0,\n", + " max_x = 1.0,\n", + " min_y = -1.0,\n", + " max_y = 1.0,\n", + " n_rays_per_image=750,\n", + " n_pts_per_ray=128,\n", + " min_depth=0.1,\n", + " max_depth=volume_extent_world,\n", + ")\n", + "\n", + "# 2) Instantiate the raymarcher.\n", + "# Here, we use the standard EmissionAbsorptionRaymarcher \n", + "# which marches along each ray in order to render\n", + "# the ray into a single 3D color vector \n", + "# and an opacity scalar.\n", + "raymarcher = EmissionAbsorptionRaymarcher()\n", + "\n", + "# Finally, instantiate the implicit renders\n", + "# for both raysamplers.\n", + "renderer_grid = ImplicitRenderer(\n", + " raysampler=raysampler_grid, raymarcher=raymarcher,\n", + ")\n", + "renderer_mc = ImplicitRenderer(\n", + " raysampler=raysampler_mc, raymarcher=raymarcher,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Define the neural radiance field model\n", + "\n", + "In this cell we define the `NeuralRadianceField` module, which specifies a continuous field of colors and opacities over the 3D domain of the scene.\n", + "\n", + "The `forward` function of `NeuralRadianceField` (NeRF) receives as input a set of tensors that parametrize a bundle of rendering rays. The ray bundle is later converted to 3D ray points in the world coordinates of the scene. Each 3D point is then mapped to a harmonic representation using the `HarmonicEmbedding` layer (defined in the next cell). The harmonic embeddings then enter the _color_ and _opacity_ branches of the NeRF model in order to label each ray point with a 3D vector and a 1D scalar ranging in [0-1] which define the point's RGB color and opacity respectively.\n", + "\n", + "Since NeRF has a large memory footprint, we also implement the `NeuralRadianceField.forward_batched` method. The method splits the input rays into batches and executes the `forward` function for each batch separately in a for loop. This allows to render a large set of rays without running out of GPU memory. Standardly, `forward_batched` would be used to render rays emitted from all pixels of an image in order to produce a full-sized render of a scene.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class HarmonicEmbedding(torch.nn.Module):\n", + " def __init__(self, n_harmonic_functions=60, omega0=0.1):\n", + " \"\"\"\n", + " Given an input tensor `x` of shape [minibatch, ... 
, dim],\n", + " the harmonic embedding layer converts each feature\n", + " in `x` into a series of harmonic features `embedding`\n", + " as follows:\n", + " embedding[..., i*dim:(i+1)*dim] = [\n", + " sin(x[..., i]),\n", + " sin(2*x[..., i]),\n", + " sin(4*x[..., i]),\n", + " ...\n", + " sin(2**self.n_harmonic_functions * x[..., i]),\n", + " cos(x[..., i]),\n", + " cos(2*x[..., i]),\n", + " cos(4*x[..., i]),\n", + " ...\n", + " cos(2**self.n_harmonic_functions * x[..., i])\n", + " ]\n", + " \n", + " Note that `x` is also premultiplied by `omega0` before\n", + " evaluting the harmonic functions.\n", + " \"\"\"\n", + " super().__init__()\n", + " self.register_buffer(\n", + " 'frequencies',\n", + " omega0 * (2.0 ** torch.arange(n_harmonic_functions)),\n", + " )\n", + " def forward(self, x):\n", + " \"\"\"\n", + " Args:\n", + " x: tensor of shape [..., dim]\n", + " Returns:\n", + " embedding: a harmonic embedding of `x`\n", + " of shape [..., n_harmonic_functions * dim * 2]\n", + " \"\"\"\n", + " embed = (x[..., None] * self.frequencies).view(*x.shape[:-1], -1)\n", + " return torch.cat((embed.sin(), embed.cos()), dim=-1)\n", + "\n", + "\n", + "class NeuralRadianceField(torch.nn.Module):\n", + " def __init__(self, n_harmonic_functions=60, n_hidden_neurons=256):\n", + " super().__init__()\n", + " \"\"\"\n", + " Args:\n", + " n_harmonic_functions: The number of harmonic functions\n", + " used to form the harmonic embedding of each point.\n", + " n_hidden_neurons: The number of hidden units in the\n", + " fully connected layers of the MLPs of the model.\n", + " \"\"\"\n", + " \n", + " # The harmonic embedding layer converts input 3D coordinates\n", + " # to a representation that is more suitable for\n", + " # processing with a deep neural network.\n", + " self.harmonic_embedding = HarmonicEmbedding(n_harmonic_functions)\n", + " \n", + " # The dimension of the harmonic embedding.\n", + " embedding_dim = n_harmonic_functions * 2 * 3\n", + " \n", + " # self.mlp is a simple 2-layer multi-layer perceptron\n", + " # which converts the input per-point harmonic embeddings\n", + " # to a latent representation.\n", + " # Not that we use Softplus activations instead of ReLU.\n", + " self.mlp = torch.nn.Sequential(\n", + " torch.nn.Linear(embedding_dim, n_hidden_neurons),\n", + " torch.nn.Softplus(beta=10.0),\n", + " torch.nn.Linear(n_hidden_neurons, n_hidden_neurons),\n", + " torch.nn.Softplus(beta=10.0),\n", + " ) \n", + " \n", + " # Given features predicted by self.mlp, self.color_layer\n", + " # is responsible for predicting a 3-D per-point vector\n", + " # that represents the RGB color of the point.\n", + " self.color_layer = torch.nn.Sequential(\n", + " torch.nn.Linear(n_hidden_neurons + embedding_dim, n_hidden_neurons),\n", + " torch.nn.Softplus(beta=10.0),\n", + " torch.nn.Linear(n_hidden_neurons, 3),\n", + " torch.nn.Sigmoid(),\n", + " # To ensure that the colors correctly range between [0-1],\n", + " # the layer is terminated with a sigmoid layer.\n", + " ) \n", + " \n", + " # The density layer converts the features of self.mlp\n", + " # to a 1D density value representing the raw opacity\n", + " # of each point.\n", + " self.density_layer = torch.nn.Sequential(\n", + " torch.nn.Linear(n_hidden_neurons, 1),\n", + " torch.nn.Softplus(beta=10.0),\n", + " # Sofplus activation ensures that the raw opacity\n", + " # is a non-negative number.\n", + " )\n", + " \n", + " # We set the bias of the density layer to -1.5\n", + " # in order to initialize the opacities of the\n", + " # ray points to values close 
to 0. \n", + " # This is a crucial detail for ensuring convergence\n", + " # of the model.\n", + " self.density_layer[0].bias.data[0] = -1.5 \n", + " \n", + " def _get_densities(self, features):\n", + " \"\"\"\n", + " This function takes `features` predicted by `self.mlp`\n", + " and converts them to `raw_densities` with `self.density_layer`.\n", + " `raw_densities` are later mapped to [0-1] range with\n", + " 1 - inverse exponential of `raw_densities`.\n", + " \"\"\"\n", + " raw_densities = self.density_layer(features)\n", + " return 1 - (-raw_densities).exp()\n", + " \n", + " def _get_colors(self, features, rays_directions):\n", + " \"\"\"\n", + " This function takes per-point `features` predicted by `self.mlp`\n", + " and evaluates the color model in order to attach to each\n", + " point a 3D vector of its RGB color.\n", + " \n", + " In order to represent viewpoint dependent effects,\n", + " before evaluating `self.color_layer`, `NeuralRadianceField`\n", + " concatenates to the `features` a harmonic embedding\n", + " of `ray_directions`, which are per-point directions \n", + " of point rays expressed as 3D l2-normalized vectors\n", + " in world coordinates.\n", + " \"\"\"\n", + " spatial_size = features.shape[:-1]\n", + " \n", + " # Normalize the ray_directions to unit l2 norm.\n", + " rays_directions_normed = torch.nn.functional.normalize(\n", + " rays_directions, dim=-1\n", + " )\n", + " \n", + " # Obtain the harmonic embedding of the normalized ray directions.\n", + " rays_embedding = self.harmonic_embedding(\n", + " rays_directions_normed\n", + " )\n", + " \n", + " # Expand the ray directions tensor so that its spatial size\n", + " # is equal to the size of features.\n", + " rays_embedding_expand = rays_embedding[..., None, :].expand(\n", + " *spatial_size, rays_embedding.shape[-1]\n", + " )\n", + " \n", + " # Concatenate ray direction embeddings with \n", + " # features and evaluate the color model.\n", + " color_layer_input = torch.cat(\n", + " (features, rays_embedding_expand),\n", + " dim=-1\n", + " )\n", + " return self.color_layer(color_layer_input)\n", + " \n", + " \n", + " def forward(\n", + " self, \n", + " ray_bundle: RayBundle,\n", + " **kwargs,\n", + " ):\n", + " \"\"\"\n", + " The forward function accepts the parametrizations of\n", + " 3D points sampled along projection rays. The forward\n", + " pass is responsible for attaching a 3D vector\n", + " and a 1D scalar representing the point's \n", + " RGB color and opacity respectively.\n", + " \n", + " Args:\n", + " ray_bundle: A RayBundle object containing the following variables:\n", + " origins: A tensor of shape `(minibatch, ..., 3)` denoting the\n", + " origins of the sampling rays in world coords.\n", + " directions: A tensor of shape `(minibatch, ..., 3)`\n", + " containing the direction vectors of sampling rays in world coords.\n", + " lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`\n", + " containing the lengths at which the rays are sampled.\n", + "\n", + " Returns:\n", + " rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`\n", + " denoting the opacitiy of each ray point.\n", + " rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`\n", + " denoting the color of each ray point.\n", + " \"\"\"\n", + " # We first convert the ray parametrizations to world\n", + " # coordinates with `ray_bundle_to_ray_points`.\n", + " rays_points_world = ray_bundle_to_ray_points(ray_bundle)\n", + " # rays_points_world.shape = [minibatch x ... 
x 3]\n", + " \n", + " # For each 3D world coordinate, we obtain its harmonic embedding.\n", + " embeds = self.harmonic_embedding(\n", + " rays_points_world\n", + " )\n", + " # embeds.shape = [minibatch x ... x self.n_harmonic_functions*6]\n", + " \n", + " # self.mlp maps each harmonic embedding to a latent feature space.\n", + " features = self.mlp(embeds)\n", + " # features.shape = [minibatch x ... x n_hidden_neurons]\n", + " \n", + " # Finally, given the per-point features, \n", + " # execute the density and color branches.\n", + " \n", + " rays_densities = self._get_densities(features)\n", + " # rays_densities.shape = [minibatch x ... x 1]\n", + "\n", + " rays_colors = self._get_colors(features, ray_bundle.directions)\n", + " # rays_colors.shape = [minibatch x ... x 3]\n", + " \n", + " return rays_densities, rays_colors\n", + " \n", + " def batched_forward(\n", + " self, \n", + " ray_bundle: RayBundle,\n", + " n_batches: int = 16,\n", + " **kwargs, \n", + " ):\n", + " \"\"\"\n", + " This function is used to allow for memory efficient processing\n", + " of input rays. The input rays are first split to `n_batches`\n", + " chunks and passed through the `self.forward` function one at a time\n", + " in a for loop. Combined with disabling Pytorch gradient caching\n", + " (`torch.no_grad()`), this allows for rendering large batches\n", + " of rays that do not all fit into GPU memory in a single forward pass.\n", + " In our case, batched_forward is used to export a fully-sized render\n", + " of the radiance field for visualisation purposes.\n", + " \n", + " Args:\n", + " ray_bundle: A RayBundle object containing the following variables:\n", + " origins: A tensor of shape `(minibatch, ..., 3)` denoting the\n", + " origins of the sampling rays in world coords.\n", + " directions: A tensor of shape `(minibatch, ..., 3)`\n", + " containing the direction vectors of sampling rays in world coords.\n", + " lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`\n", + " containing the lengths at which the rays are sampled.\n", + " n_batches: Specifies the number of batches the input rays are split into.\n", + " The larger the number of batches, the smaller the memory footprint\n", + " and the lower the processing speed.\n", + "\n", + " Returns:\n", + " rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`\n", + " denoting the opacitiy of each ray point.\n", + " rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`\n", + " denoting the color of each ray point.\n", + "\n", + " \"\"\"\n", + "\n", + " # Parse out shapes needed for tensor reshaping in this function.\n", + " n_pts_per_ray = ray_bundle.lengths.shape[-1] \n", + " spatial_size = [*ray_bundle.origins.shape[:-1], n_pts_per_ray]\n", + "\n", + " # Split the rays to `n_batches` batches.\n", + " tot_samples = ray_bundle.origins.shape[:-1].numel()\n", + " batches = torch.chunk(torch.arange(tot_samples), n_batches)\n", + "\n", + " # For each batch, execute the standard forward pass.\n", + " batch_outputs = [\n", + " self.forward(\n", + " RayBundle(\n", + " origins=ray_bundle.origins.view(-1, 3)[batch_idx],\n", + " directions=ray_bundle.directions.view(-1, 3)[batch_idx],\n", + " lengths=ray_bundle.lengths.view(-1, n_pts_per_ray)[batch_idx],\n", + " xys=None,\n", + " )\n", + " ) for batch_idx in batches\n", + " ]\n", + " \n", + " # Concatenate the per-batch rays_densities and rays_colors\n", + " # and reshape according to the sizes of the inputs.\n", + " rays_densities, rays_colors = [\n", + " 
torch.cat(\n", + " [batch_output[output_i] for batch_output in batch_outputs], dim=0\n", + " ).view(*spatial_size, -1) for output_i in (0, 1)\n", + " ]\n", + " return rays_densities, rays_colors" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Helper functions\n", + "\n", + "In this function we define functions that help with the Neural Radiance Field optimization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def huber(x, y, scaling=0.1):\n", + " \"\"\"\n", + " A helper function for evaluating the smooth L1 (huber) loss\n", + " between the rendered silhouettes and colors.\n", + " \"\"\"\n", + " diff_sq = (x - y) ** 2\n", + " loss = ((1 + diff_sq / (scaling**2)).clamp(1e-4).sqrt() - 1) * float(scaling)\n", + " return loss\n", + "\n", + "def sample_images_at_mc_locs(target_images, sampled_rays_xy):\n", + " \"\"\"\n", + " Given a set of Monte Carlo pixel locations `sampled_rays_xy`,\n", + " this method samples the tensor `target_images` at the\n", + " respective 2D locations.\n", + " \n", + " This function is used in order to extract the colors from\n", + " ground truth images that correspond to the colors\n", + " rendered using `MonteCarloRaysampler`.\n", + " \"\"\"\n", + " ba = target_images.shape[0]\n", + " dim = target_images.shape[-1]\n", + " spatial_size = sampled_rays_xy.shape[1:-1]\n", + " # In order to sample target_images, we utilize\n", + " # the grid_sample function which implements a\n", + " # bilinear image sampler.\n", + " # Note that we have to invert the sign of the \n", + " # sampled ray positions to convert the NDC xy locations\n", + " # of the MonteCarloRaysampler to the coordinate\n", + " # convention of grid_sample.\n", + " images_sampled = torch.nn.functional.grid_sample(\n", + " target_images.permute(0, 3, 1, 2), \n", + " -sampled_rays_xy.view(ba, -1, 1, 2), # note the sign inversion\n", + " align_corners=True\n", + " )\n", + " return images_sampled.permute(0, 2, 3, 1).view(\n", + " ba, *spatial_size, dim\n", + " )\n", + "\n", + "def show_full_render(\n", + " neural_radiance_field, camera,\n", + " target_image, target_silhouette,\n", + " loss_history_color, loss_history_sil,\n", + "):\n", + " \"\"\"\n", + " This is a helper function for visualizing the\n", + " intermediate results of the learning. 
\n",
+    "    \n",
+    "    Since the `NeuralRadianceField` has a large memory\n",
+    "    footprint, which does not allow rendering\n",
+    "    the full image grid in a single forward pass,\n",
+    "    we utilize the `NeuralRadianceField.batched_forward`\n",
+    "    function in combination with disabling the gradient caching.\n",
+    "    This chunks the set of emitted rays into batches and\n",
+    "    evaluates the implicit function one batch at a time\n",
+    "    to prevent GPU memory overflow.\n",
+    "    \"\"\"\n",
+    "    \n",
+    "    # Prevent gradient caching.\n",
+    "    with torch.no_grad():\n",
+    "        # Render using the grid renderer and the\n",
+    "        # batched_forward function of neural_radiance_field.\n",
+    "        rendered_image_silhouette, _ = renderer_grid(\n",
+    "            cameras=camera, \n",
+    "            volumetric_function=neural_radiance_field.batched_forward\n",
+    "        )\n",
+    "        # Split the rendering result into an image render\n",
+    "        # and a silhouette render.\n",
+    "        rendered_image, rendered_silhouette = (\n",
+    "            rendered_image_silhouette[0].split([3, 1], dim=-1)\n",
+    "        )\n",
+    "    \n",
+    "    # Generate plots.\n",
+    "    fig, ax = plt.subplots(2, 3, figsize=(15, 10))\n",
+    "    ax = ax.ravel()\n",
+    "    clamp_and_detach = lambda x: x.clamp(0.0, 1.0).cpu().detach().numpy()\n",
+    "    ax[0].plot(list(range(len(loss_history_color))), loss_history_color, linewidth=1)\n",
+    "    ax[1].imshow(clamp_and_detach(rendered_image))\n",
+    "    ax[2].imshow(clamp_and_detach(rendered_silhouette[..., 0]))\n",
+    "    ax[3].plot(list(range(len(loss_history_sil))), loss_history_sil, linewidth=1)\n",
+    "    ax[4].imshow(clamp_and_detach(target_image))\n",
+    "    ax[5].imshow(clamp_and_detach(target_silhouette))\n",
+    "    for ax_, title_ in zip(\n",
+    "        ax,\n",
+    "        (\n",
+    "            \"loss color\", \"rendered image\", \"rendered silhouette\",\n",
+    "            \"loss silhouette\", \"target image\", \"target silhouette\",\n",
+    "        )\n",
+    "    ):\n",
+    "        if not title_.startswith('loss'):\n",
+    "            ax_.grid(\"off\")\n",
+    "            ax_.axis(\"off\")\n",
+    "        ax_.set_title(title_)\n",
+    "    fig.canvas.draw(); fig.show()\n",
+    "    display.clear_output(wait=True)\n",
+    "    display.display(fig)\n",
+    "    return fig\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 5. Fit the radiance field\n",
+    "\n",
+    "Here we carry out the radiance field fitting with differentiable rendering.\n",
+    "\n",
+    "In order to fit the radiance field, we render it from the viewpoints of the `target_cameras`\n",
+    "and compare the resulting renders with the observed `target_images` and `target_silhouettes`.\n",
+    "\n",
+    "The comparison is done by evaluating the mean huber (smooth-l1) error between corresponding\n",
+    "pairs of `target_images`/`rendered_images` and `target_silhouettes`/`rendered_silhouettes`.\n",
+    "\n",
+    "Since we use the `MonteCarloRaysampler`, the outputs of the training renderer `renderer_mc`\n",
+    "are colors of pixels that are randomly sampled from the image plane, not a lattice of pixels forming\n",
+    "a valid image. Thus, in order to compare the rendered colors with the ground truth, we\n",
+    "utilize the random Monte Carlo pixel locations to sample the ground truth images/silhouettes\n",
+    "`target_images`/`target_silhouettes` at the xy locations corresponding to the render\n",
+    "locations. This is done with the helper function `sample_images_at_mc_locs`, which is\n",
+    "described in the previous cell.\n",
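+    "\n",
+    "As a quick, self-contained illustration of this ground-truth sampling step (a toy sketch that is not part of the training loop; the image size and the number of sampled locations below are made up for demonstration), the following shows how NDC ray locations can be looked up in an image with `torch.nn.functional.grid_sample`, including the sign flip between the PyTorch3D NDC convention and the `grid_sample` convention:\n",
+    "\n",
+    "```python\n",
+    "import torch\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "# A hypothetical 10x10 RGB image, shaped [batch, H, W, 3] like target_images.\n",
+    "image = torch.rand(1, 10, 10, 3)\n",
+    "# 750 random locations in the [-1, 1] x [-1, 1] NDC square,\n",
+    "# analogous to the xys returned by MonteCarloRaysampler.\n",
+    "xys = torch.rand(1, 750, 2) * 2 - 1\n",
+    "sampled = F.grid_sample(\n",
+    "    image.permute(0, 3, 1, 2),  # grid_sample expects [batch, C, H, W]\n",
+    "    -xys[:, :, None, :],        # note the sign flip of the NDC coordinates\n",
+    "    align_corners=True,\n",
+    ")\n",
+    "print(sampled.permute(0, 2, 3, 1).shape)  # torch.Size([1, 750, 1, 3])\n",
+    "```"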
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# First move all relevant variables to the correct device.\n", + "renderer_grid = renderer_grid.to(device)\n", + "renderer_mc = renderer_mc.to(device)\n", + "target_cameras = target_cameras.to(device)\n", + "target_images = target_images.to(device)\n", + "target_silhouettes = target_silhouettes.to(device)\n", + "\n", + "# Set the seed for reproducibility\n", + "torch.manual_seed(1)\n", + "\n", + "# Instantiate the radiance field model.\n", + "neural_radiance_field = NeuralRadianceField().to(device)\n", + "\n", + "# Instantiate the Adam optimizer. We set its master learning rate to 1e-3.\n", + "lr = 1e-3\n", + "optimizer = torch.optim.Adam(neural_radiance_field.parameters(), lr=lr)\n", + "\n", + "# We sample 6 random cameras in a minibatch. Each camera\n", + "# emits raysampler_mc.n_pts_per_image rays.\n", + "batch_size = 6\n", + "\n", + "# 3000 iterations take ~20 min on a Tesla M40 and lead to\n", + "# reasonably sharp results. However, for the best possible\n", + "# results, we recommend setting n_iter=20000.\n", + "n_iter = 3000\n", + "\n", + "# Init the loss history buffers.\n", + "loss_history_color, loss_history_sil = [], []\n", + "\n", + "# The main optimization loop.\n", + "for iteration in range(n_iter): \n", + " # In case we reached the last 75% of iterations,\n", + " # decrease the learning rate of the optimizer 10-fold.\n", + " if iteration == round(n_iter * 0.75):\n", + " print('Decreasing LR 10-fold ...')\n", + " optimizer = torch.optim.Adam(\n", + " neural_radiance_field.parameters(), lr=lr * 0.1\n", + " )\n", + " \n", + " # Zero the optimizer gradient.\n", + " optimizer.zero_grad()\n", + " \n", + " # Sample random batch indices.\n", + " batch_idx = torch.randperm(len(target_cameras))[:batch_size]\n", + " \n", + " # Sample the minibatch of cameras.\n", + " batch_cameras = FoVPerspectiveCameras(\n", + " R = target_cameras.R[batch_idx], \n", + " T = target_cameras.T[batch_idx], \n", + " znear = target_cameras.znear[batch_idx],\n", + " zfar = target_cameras.zfar[batch_idx],\n", + " aspect_ratio = target_cameras.aspect_ratio[batch_idx],\n", + " fov = target_cameras.fov[batch_idx],\n", + " device = device,\n", + " )\n", + " \n", + " # Evaluate the nerf model.\n", + " rendered_images_silhouettes, sampled_rays = renderer_mc(\n", + " cameras=batch_cameras, \n", + " volumetric_function=neural_radiance_field\n", + " )\n", + " rendered_images, rendered_silhouettes = (\n", + " rendered_images_silhouettes.split([3, 1], dim=-1)\n", + " )\n", + " \n", + " # Compute the silhoutte error as the mean huber\n", + " # loss between the predicted masks and the\n", + " # sampled target silhouettes.\n", + " silhouettes_at_rays = sample_images_at_mc_locs(\n", + " target_silhouettes[batch_idx, ..., None], \n", + " sampled_rays.xys\n", + " )\n", + " sil_err = huber(\n", + " rendered_silhouettes, \n", + " silhouettes_at_rays,\n", + " ).abs().mean()\n", + "\n", + " # Compute the color error as the mean huber\n", + " # loss between the rendered colors and the\n", + " # sampled target images.\n", + " colors_at_rays = sample_images_at_mc_locs(\n", + " target_images[batch_idx], \n", + " sampled_rays.xys\n", + " )\n", + " color_err = huber(\n", + " rendered_images, \n", + " colors_at_rays,\n", + " ).abs().mean()\n", + " \n", + " # The optimization loss is a simple\n", + " # sum of the color and silhouette errors.\n", + " loss = color_err + sil_err\n", + " \n", + " # Log the loss history.\n", + " 
loss_history_color.append(float(color_err))\n",
+    "    loss_history_sil.append(float(sil_err))\n",
+    "    \n",
+    "    # Every 10 iterations, print the current values of the losses.\n",
+    "    if iteration % 10 == 0:\n",
+    "        print(\n",
+    "            f'Iteration {iteration:05d}:'\n",
+    "            + f' loss color = {float(color_err):1.2e}'\n",
+    "            + f' loss silhouette = {float(sil_err):1.2e}'\n",
+    "        )\n",
+    "    \n",
+    "    # Take the optimization step.\n",
+    "    loss.backward()\n",
+    "    optimizer.step()\n",
+    "    \n",
+    "    # Visualize the full renders every 100 iterations.\n",
+    "    if iteration % 100 == 0:\n",
+    "        show_idx = torch.randperm(len(target_cameras))[:1]\n",
+    "        show_full_render(\n",
+    "            neural_radiance_field,\n",
+    "            FoVPerspectiveCameras(\n",
+    "                R = target_cameras.R[show_idx], \n",
+    "                T = target_cameras.T[show_idx], \n",
+    "                znear = target_cameras.znear[show_idx],\n",
+    "                zfar = target_cameras.zfar[show_idx],\n",
+    "                aspect_ratio = target_cameras.aspect_ratio[show_idx],\n",
+    "                fov = target_cameras.fov[show_idx],\n",
+    "                device = device,\n",
+    "            ), \n",
+    "            target_images[show_idx][0],\n",
+    "            target_silhouettes[show_idx][0],\n",
+    "            loss_history_color,\n",
+    "            loss_history_sil,\n",
+    "        )"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 6. Visualizing the optimized neural radiance field\n",
+    "\n",
+    "Finally, we visualize the neural radiance field by rendering from multiple viewpoints that rotate around the volume's y-axis."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def generate_rotating_nerf(neural_radiance_field, n_frames = 50):\n",
+    "    logRs = torch.zeros(n_frames, 3, device=device)\n",
+    "    logRs[:, 1] = torch.linspace(-3.14, 3.14, n_frames, device=device)\n",
+    "    Rs = so3_exponential_map(logRs)\n",
+    "    Ts = torch.zeros(n_frames, 3, device=device)\n",
+    "    Ts[:, 2] = 2.7\n",
+    "    frames = []\n",
+    "    print('Rendering rotating NeRF ...')\n",
+    "    for R, T in zip(tqdm(Rs), Ts):\n",
+    "        camera = FoVPerspectiveCameras(\n",
+    "            R=R[None], \n",
+    "            T=T[None], \n",
+    "            znear=target_cameras.znear[0],\n",
+    "            zfar=target_cameras.zfar[0],\n",
+    "            aspect_ratio=target_cameras.aspect_ratio[0],\n",
+    "            fov=target_cameras.fov[0],\n",
+    "            device=device,\n",
+    "        )\n",
+    "        # Note that we again render with `NDCGridRaysampler`\n",
+    "        # and the batched_forward function of neural_radiance_field.\n",
+    "        frames.append(\n",
+    "            renderer_grid(\n",
+    "                cameras=camera, \n",
+    "                volumetric_function=neural_radiance_field.batched_forward,\n",
+    "            )[0][..., :3]\n",
+    "        )\n",
+    "    return torch.cat(frames)\n",
+    "    \n",
+    "with torch.no_grad():\n",
+    "    rotating_nerf_frames = generate_rotating_nerf(neural_radiance_field, n_frames=3*5)\n",
+    "    \n",
+    "image_grid(rotating_nerf_frames.clamp(0., 1.).cpu().numpy(), rows=3, cols=5, rgb=True, fill=True)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 7. Conclusion\n",
+    "\n",
+    "In this tutorial, we have shown how to optimize an implicit representation of a scene such that the renders of the scene from known viewpoints match the observed images for each viewpoint. The rendering was carried out using PyTorch3D's implicit function renderer composed of either a `MonteCarloRaysampler` or an `NDCGridRaysampler`, and an `EmissionAbsorptionRaymarcher`.\n",
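+    "\n",
+    "For reference, the core of the emission-absorption compositing performed by the raymarcher can be sketched in a few lines. This is an illustrative simplification, not the actual `EmissionAbsorptionRaymarcher` implementation (which, for example, also takes care of numerical stability):\n",
+    "\n",
+    "```python\n",
+    "import torch\n",
+    "\n",
+    "def emission_absorption(rays_densities, rays_colors):\n",
+    "    # rays_densities: [..., n_points_per_ray, 1] with values in [0, 1].\n",
+    "    # rays_colors:    [..., n_points_per_ray, 3].\n",
+    "    transmittance = torch.cumprod(1.0 - rays_densities, dim=-2)\n",
+    "    # Shift so that each point is attenuated only by the points in front of it.\n",
+    "    transmittance = torch.cat(\n",
+    "        (torch.ones_like(transmittance[..., :1, :]), transmittance[..., :-1, :]),\n",
+    "        dim=-2,\n",
+    "    )\n",
+    "    weights = rays_densities * transmittance\n",
+    "    color = (weights * rays_colors).sum(dim=-2)  # rendered pixel color\n",
+    "    opacity = weights.sum(dim=-2)                # rendered silhouette value\n",
+    "    # The 4-channel result matches the render that is split into [3, 1] above.\n",
+    "    return torch.cat((color, opacity), dim=-1)\n",
+    "```"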
+ ] + } + ], + "metadata": { + "bento_stylesheets": { + "bento/extensions/flow/main.css": true, + "bento/extensions/kernel_selector/main.css": true, + "bento/extensions/kernel_ui/main.css": true, + "bento/extensions/new_kernel/main.css": true, + "bento/extensions/system_usage/main.css": true, + "bento/extensions/theme/main.css": true + }, + "kernelspec": { + "display_name": "pytorch3d_etc (local)", + "language": "python", + "name": "pytorch3d_etc_local" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5+" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/tutorials/utils/generate_cow_renders.py b/docs/tutorials/utils/generate_cow_renders.py index fc080e461..74b08ba6b 100644 --- a/docs/tutorials/utils/generate_cow_renders.py +++ b/docs/tutorials/utils/generate_cow_renders.py @@ -19,12 +19,15 @@ look_at_view_transform, ) + # create the default data directory current_dir = os.path.dirname(os.path.realpath(__file__)) DATA_DIR = os.path.join(current_dir, "..", "data", "cow_mesh") -def generate_cow_renders(num_views: int = 40, data_dir: str = DATA_DIR): +def generate_cow_renders( + num_views: int = 40, data_dir: str = DATA_DIR, azimuth_range: float = 180 +): """ This function generates `num_views` renders of a cow mesh. The renders are generated from viewpoints sampled at uniformly distributed @@ -94,7 +97,7 @@ def generate_cow_renders(num_views: int = 40, data_dir: str = DATA_DIR): # Get a batch of viewing angles. elev = torch.linspace(0, 0, num_views) # keep constant - azim = torch.linspace(-180, 180, num_views) + azim = torch.linspace(-azimuth_range, azimuth_range, num_views) + 180.0 # Place a point light in front of the object. As mentioned above, the front of # the cow is facing the -z direction.