diff --git a/docs/notes/renderer_getting_started.md b/docs/notes/renderer_getting_started.md index e28cd25b7..ae0b95271 100644 --- a/docs/notes/renderer_getting_started.md +++ b/docs/notes/renderer_getting_started.md @@ -84,7 +84,8 @@ For mesh texturing we offer several options (in `pytorch3d/renderer/mesh/texturi 1. **Vertex Textures**: D dimensional textures for each vertex (for example an RGB color) which can be interpolated across the face. This can be represented as an `(N, V, D)` tensor. This is a fairly simple representation though and cannot model complex textures if the mesh faces are large. 2. **UV Textures**: vertex UV coordinates and **one** texture map for the whole mesh. For a point on a face with given barycentric coordinates, the face color can be computed by interpolating the vertex uv coordinates and then sampling from the texture map. This representation requires two tensors (UVs: `(N, V, 2), Texture map: `(N, H, W, 3)`), and is limited to only support one texture map per mesh. -3. **Face Textures**: In more complex cases such as ShapeNet meshes, there are multiple texture maps per mesh and some faces have texture while other do not. For these cases, a more flexible representation is a texture atlas, where each face is represented as an `(RxR)` texture map where R is the texture resolution. For a given point on the face, the texture value can be sampled from the per face texture map using the barycentric coordinates of the point. This representation requires one tensor of shape `(N, F, R, R, 3)`. This texturing method is inspired by the SoftRasterizer implementation. For more details refer to the [`make_material_atlas`](https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/io/mtl_io.py#L123) and [`sample_textures`](https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/renderer/mesh/textures.py#L452) functions. +3. 
**Face Textures**: In more complex cases such as ShapeNet meshes, there are multiple texture maps per mesh and some faces have texture while others do not. For these cases, a more flexible representation is a texture atlas, where each face is represented as an `(RxR)` texture map where R is the texture resolution. For a given point on the face, the texture value can be sampled from the per face texture map using the barycentric coordinates of the point. This representation requires one tensor of shape `(N, F, R, R, 3)`. This texturing method is inspired by the SoftRasterizer implementation. For more details refer to the [`make_material_atlas`](https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/io/mtl_io.py#L123) and [`sample_textures`](https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/renderer/mesh/textures.py#L452) functions. **NOTE:** The `TextureAtlas` texture sampling is only differentiable with respect to the texture atlas but not differentiable with respect to the barycentric coordinates. + diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index 9a58b841b..9af8611e7 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -479,6 +479,18 @@ def extend(self, N: int) -> "TexturesAtlas": def sample_textures(self, fragments, **kwargs) -> torch.Tensor: """ + This is similar to a nearest neighbor sampling and involves a + discretization step. The barycentric coordinates from + rasterization are used to find the nearest grid cell in the texture + atlas and the RGB is returned as the color. + This means that this step is differentiable with respect to the RGB + values of the texture atlas but not differentiable with respect to the + barycentric coordinates. + + TODO: Add a different sampling mode which interpolates the barycentric + coordinates to sample the texture and will be differentiable w.r.t + the barycentric coordinates. 
+ Args: fragments: The outputs of rasterization. From this we use @@ -504,7 +516,10 @@ def sample_textures(self, fragments, **kwargs) -> torch.Tensor: # pyre-fixme[16]: `bool` has no attribute `__getitem__`. mask = (pix_to_face < 0)[..., None] bary_w01 = torch.where(mask, torch.zeros_like(bary_w01), bary_w01) - w_xy = (bary_w01 * R).to(torch.int64) # (N, H, W, K, 2) + # If barycentric coordinates are > 1.0 (in the case of + # blur_radius > 0.0), wxy might be > R. We need to clamp this + # index to R-1 to index into the texture atlas. + w_xy = (bary_w01 * R).to(torch.int64).clamp(max=R - 1) # (N, H, W, K, 2) below_diag = ( bary_w01.sum(dim=-1) * R - w_xy.float().sum(dim=-1) diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py index 344b32d30..5c46bd250 100644 --- a/tests/test_render_meshes.py +++ b/tests/test_render_meshes.py @@ -956,6 +956,7 @@ def test_joined_spheres(self): def test_texture_map_atlas(self): """ Test a mesh with a texture map as a per face atlas is loaded and rendered correctly. + Also check that the backward pass for texture atlas rendering is differentiable. 
""" device = torch.device("cuda:0") obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data" @@ -970,10 +971,11 @@ def test_texture_map_atlas(self): texture_atlas_size=8, texture_wrap=None, ) + atlas = aux.texture_atlas mesh = Meshes( verts=[verts], faces=[faces.verts_idx], - textures=TexturesAtlas(atlas=[aux.texture_atlas]), + textures=TexturesAtlas(atlas=[atlas]), ) # Init rasterizer settings @@ -981,7 +983,10 @@ def test_texture_map_atlas(self): cameras = FoVPerspectiveCameras(device=device, R=R, T=T) raster_settings = RasterizationSettings( - image_size=512, blur_radius=0.0, faces_per_pixel=1, cull_backfaces=True + image_size=512, + blur_radius=0.0, + faces_per_pixel=1, + cull_backfaces=True, ) # Init shader settings @@ -993,23 +998,52 @@ def test_texture_map_atlas(self): lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None] # The HardPhongShader can be used directly with atlas textures. + rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) renderer = MeshRenderer( - rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), + rasterizer=rasterizer, shader=HardPhongShader(lights=lights, cameras=cameras, materials=materials), ) images = renderer(mesh) - rgb = images[0, ..., :3].squeeze().cpu() + rgb = images[0, ..., :3].squeeze() # Load reference image image_ref = load_rgb_image("test_texture_atlas_8x8_back.png", DATA_DIR) if DEBUG: - Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save( + Image.fromarray((rgb.detach().cpu().numpy() * 255).astype(np.uint8)).save( DATA_DIR / "DEBUG_texture_atlas_8x8_back.png" ) - self.assertClose(rgb, image_ref, atol=0.05) + self.assertClose(rgb.cpu(), image_ref, atol=0.05) + + # Check gradients are propagated + # correctly back to the texture atlas. + # Because of how texture sampling is implemented + # for the texture atlas it is not possible to get + # gradients back to the vertices. 
+ atlas.requires_grad = True + mesh = Meshes( + verts=[verts], + faces=[faces.verts_idx], + textures=TexturesAtlas(atlas=[atlas]), + ) + raster_settings = RasterizationSettings( + image_size=512, + blur_radius=0.0001, + faces_per_pixel=5, + cull_backfaces=True, + clip_barycentric_coords=True, + ) + images = renderer(mesh, raster_settings=raster_settings) + images[0, ...].sum().backward() + + fragments = rasterizer(mesh, raster_settings=raster_settings) + # Some of the bary coordinates are outside the + # [0, 1] range as expected because the blur is > 0 + self.assertTrue(fragments.bary_coords.ge(1.0).any()) + self.assertIsNotNone(atlas.grad) + self.assertTrue(atlas.grad.sum().abs() > 0.0) def test_simple_sphere_outside_zfar(self): """