# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
在本演示中,我们使用 PyTorch3D 中的 VolumeRenderer 作为 Implicitron 中的自定义隐式函数。我们将看到 Implicitron 中的一些主要对象,以及如何将自定义组件接入模型。
请确保已安装 `torch` 和 `torchvision`。如果尚未安装 `pytorch3d`,请使用以下单元格进行安装。
import os
import sys
import torch
need_pytorch3d=False
try:
import pytorch3d
except ModuleNotFoundError:
need_pytorch3d=True
if need_pytorch3d:
if torch.__version__.startswith("2.2.") and sys.platform.startswith("linux"):
# We try to install PyTorch3D via a released wheel.
pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
version_str="".join([
f"py3{sys.version_info.minor}_cu",
torch.version.cuda.replace(".",""),
f"_pyt{pyt_version_str}"
])
!pip install fvcore iopath
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
else:
# We try to install PyTorch3D from source.
!pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
确保已安装 omegaconf 和 visdom。如果没有,请运行此单元格。(无需重新启动运行时。)
# Notebook shell command: install the omegaconf and visdom packages with pip.
!pip install omegaconf visdom
import logging
from typing import Tuple
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from IPython.display import HTML
from omegaconf import OmegaConf
from PIL import Image
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase, ImplicitronRayBundle
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import get_default_args, registry, remove_unused_components
from pytorch3d.renderer.implicit.renderer import VolumeSampler
from pytorch3d.structures import Volumes
from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene
# Side length in pixels, used both as the resolution of the rendered training
# data and as the height and width of the images rendered by the model.
output_resolution = 80
# Print tensors in plain decimal notation instead of scientific notation.
torch.set_printoptions(sci_mode=False)
Implicitron 中数据集的训练、验证和测试部分表示为 `dataset_map`,并由 `DatasetMapProvider` 的实现提供。`RenderedMeshDatasetMapProvider` 是其中一种实现:它获取一个网格并对其进行渲染,从而生成只包含训练部分的单场景数据集。我们将其用于奶牛网格。
如果使用 **Google Colab** 运行此笔记本,请运行以下单元格以获取网格 obj 和纹理文件并将其保存在 data/cow_mesh 路径下。如果在本地运行,则数据已在正确的路径下可用。
# Notebook shell commands: create data/cow_mesh and download the cow mesh
# (.obj), its material file (.mtl) and its texture image into that directory.
!mkdir -p data/cow_mesh
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl
!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png
# Build the dataset by rendering the cow mesh; only the `train` part of the
# resulting dataset map is read here.
cow_provider = RenderedMeshDatasetMapProvider(
    data_file="data/cow_mesh/cow.obj",
    use_point_light=False,
    resolution=output_resolution,
)
dataset_map = cow_provider.get_dataset_map()
tr_cameras = [train_frame.camera for train_frame in dataset_map.train]

# The cameras are all in the XZ plane, in a circle about 2.7 from the origin.
# Printing the per-axis minimum and maximum of the camera centres shows this.
centers = torch.cat([cam.get_camera_center() for cam in tr_cameras])
print(centers.min(dim=0).values)
print(centers.max(dim=0).values)

# Visualization of the cameras: one trace per camera, keyed by its index,
# inside a single subplot titled "k".
plot = plot_scene({"k": dict(enumerate(tr_cameras))}, camera_scale=0.25)
plot.layout.scene.aspectmode = "data"
plot
神经渲染方法的核心是称为隐式函数的空间坐标函数,这些函数用于某种渲染过程中。(通常,这些函数还可以接收其他数据,例如视角方向。)常见的渲染过程是在隐式函数提供的密度和颜色上进行光线步进。在我们的例子中,从 3D 体积网格中采样是一个非常简单的空间坐标函数。
在这里,我们定义自己的隐式函数,该函数使用 PyTorch3D 的现有功能从体积网格中采样。我们通过继承 `ImplicitFunctionBase` 来实现这一点。我们需要使用特殊的装饰器来注册我们的子类,并使用 Python 的 dataclass 风格类型注解来配置该模块。
@registry.register
class MyVolumes(ImplicitFunctionBase, torch.nn.Module):
    """
    Implicit function backed by a trainable voxel grid of densities and colors.

    The class is registered with the Implicitron registry so that it can be
    selected by its name, "MyVolumes". Its configurable fields are declared
    as annotated class attributes with defaults.
    """

    # Common HWD of volumes, the number of voxels in each direction.
    grid_resolution: int = 50
    # In world coordinates, the volume occupies [-extent, extent] along each axis.
    extent: float = 1.0

    def __post_init__(self):
        """Create the trainable density and color grids and the density activation."""
        # We have to call this explicitly if there are other base classes like Module
        super().__init__()

        # We define parameters like other torch.nn.Module objects.
        # In this case, both our parameter tensors are trainable; they govern the contents of the volume grid.
        density = torch.full((self.grid_resolution, self.grid_resolution, self.grid_resolution), -2.0)
        self.density = torch.nn.Parameter(density)
        color = torch.full((3, self.grid_resolution, self.grid_resolution, self.grid_resolution), 0.0)
        self.color = torch.nn.Parameter(color)
        # Softplus maps the unconstrained density parameter to positive values.
        self.density_activation = torch.nn.Softplus()

    def forward(
        self,
        ray_bundle: ImplicitronRayBundle,
        fun_viewpool=None,
        global_code=None,
    ):
        """
        Sample the voxel grid at the points of `ray_bundle`.

        Args:
            ray_bundle: the rays along which the volume is sampled.
            fun_viewpool: accepted for interface compatibility; not used here.
            global_code: accepted for interface compatibility; not used here.

        Returns:
            A tuple (densities, features, aux) where densities and features
            are the outputs of the VolumeSampler and aux is an empty dict.
        """
        # Add a batch dimension and a channel dimension to the density grid.
        densities = self.density_activation(self.density[None, None])
        # The grid spans 2 * extent per axis, split into grid_resolution voxels.
        voxel_size = 2.0 * float(self.extent) / self.grid_resolution
        # Colors are squashed to (0, 1) and given a batch dimension.
        features = self.color.sigmoid()[None]

        # Like other PyTorch3D structures, the actual Volumes object should only exist as long
        # as one iteration of training. It is local to this function.
        volume = Volumes(densities=densities, features=features, voxel_size=voxel_size)
        sampler = VolumeSampler(volumes=volume)
        densities, features = sampler(ray_bundle)

        # When an implicit function is used for raymarching, i.e. for MultiPassEmissionAbsorptionRenderer,
        # it must return (densities, features, an auxiliary tuple)
        return densities, features, {}
PyTorch3D 中的主要模型对象是 `GenericModel`,它具有用于主要步骤的可插拔组件,包括渲染器和隐式函数。有两种构造它的方法,在这里它们是等价的。
# Toggle between the two construction routes below; both set the same options
# and leave the model in `gm` and its configuration in `cfg`.
CONSTRUCT_MODEL_FROM_CONFIG = True
if CONSTRUCT_MODEL_FROM_CONFIG:
    # Via a DictConfig - this is how our training loop with hydra works
    cfg = get_default_args(GenericModel)
    # Select our registered implicit function by its class name.
    cfg.implicit_function_class_type = "MyVolumes"
    cfg.render_image_height = output_resolution
    cfg.render_image_width = output_resolution
    cfg.loss_weights = {"loss_rgb_huber": 1.0}
    cfg.tqdm_trigger_threshold = 19000
    cfg.raysampler_AdaptiveRaySampler_args.scene_extent = 4.0
    gm = GenericModel(**cfg)
else:
    # constructing GenericModel directly
    gm = GenericModel(
        implicit_function_class_type="MyVolumes",
        render_image_height=output_resolution,
        render_image_width=output_resolution,
        loss_weights={"loss_rgb_huber": 1.0},
        tqdm_trigger_threshold=19000,
        raysampler_AdaptiveRaySampler_args={"scene_extent": 4.0},
    )

    # In this case we can get the equivalent DictConfig cfg object to the way gm is configured as follows
    cfg = OmegaConf.structured(gm)
默认渲染器是发射吸收光线步进器。我们保留该默认值。
# We can display the configuration in use as follows.
# remove_unused_components is applied to cfg first; cfg is then serialised to
# YAML without sorting its keys, and the resulting string is shown with the
# notebook pager (%page -r pages the raw string).
remove_unused_components(cfg)
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
%page -r yaml
# The model and the training data are placed on the first CUDA device.
device = torch.device("cuda:0")
gm.to(device)
# Sanity check: the first model parameter must now live on a CUDA device.
assert next(gm.parameters()).is_cuda
# Move every training frame to the device once, up front, and collate each
# one into a batch that contains just that frame.
train_data_collated = [FrameData.collate([frame.to(device)]) for frame in dataset_map.train]
# Optimise the model parameters for 2000 steps, one training frame per step.
gm.train()
optimizer = torch.optim.Adam(gm.parameters(), lr=0.1)
iterator = tqdm.tqdm(range(2000))
for n_batch in iterator:
    optimizer.zero_grad()

    # Cycle through the training frames in order, wrapping around at the end.
    frame = train_data_collated[n_batch % len(dataset_map.train)]
    out = gm(**frame, evaluation_mode=EvaluationMode.TRAINING)
    out["objective"].backward()
    # Refresh the loss shown on the progress bar every 100 steps.
    if n_batch % 100 == 0:
        iterator.set_postfix_str(f"loss: {float(out['objective']):.5f}")
    optimizer.step()
我们从所有视点生成完整的图像,以查看它们的外观。
def to_numpy_image(image):
    """
    Convert an image tensor to a numpy uint8 image.

    Args:
        image: tensor of shape (C, H, W) with values in [0, 1], where C is
            3 or 1.

    Returns:
        A numpy uint8 array of shape (H, W, 3). Values are scaled by 255 and
        truncated towards zero; a single channel is repeated three times.
    """
    # Clamp first: casting a float outside [0, 255] to uint8 is not well
    # defined, so slightly out-of-range inputs must not reach the cast.
    scaled = image.detach().clamp(0.0, 1.0) * 255
    # (C, H, W) -> (H, W, C), on the CPU so that it can be handed to numpy.
    hwc = scaled.to(torch.uint8).permute(1, 2, 0).cpu()
    # expand is a no-op for C == 3 and broadcasts the channel for C == 1.
    return hwc.expand(-1, -1, 3).numpy()
def resize_image(image, size=None):
    """
    Resize a batch of images to a square resolution.

    Args:
        image: tensor of shape (B, C, H, W).
        size: side length of the output in pixels. When None (the default),
            the module-level `output_resolution` is used.

    Returns:
        A tensor of shape (B, C, size, size), produced by
        torch.nn.functional.interpolate with its default mode.
    """
    if size is None:
        size = output_resolution
    return torch.nn.functional.interpolate(image, size=(size, size))
# Render every training view with the model in evaluation mode and collect,
# as numpy images, the predictions next to the corresponding ground truth.
gm.eval()
images = []
expected = []
masks = []
masks_expected = []
for frame in tqdm.tqdm(train_data_collated):
    # No gradients are needed for evaluation.
    with torch.no_grad():
        out = gm(**frame, evaluation_mode=EvaluationMode.EVALUATION)

    # Index [0] selects the first element of the batch dimension.
    image_rgb = to_numpy_image(out["images_render"][0])
    mask = to_numpy_image(out["masks_render"][0])
    # The ground truth is resized before conversion so that it is compared
    # at the output resolution.
    expd = to_numpy_image(resize_image(frame.image_rgb)[0])
    mask_expected = to_numpy_image(resize_image(frame.fg_probability)[0])

    images.append(image_rgb)
    masks.append(mask)
    expected.append(expd)
    masks_expected.append(mask_expected)
我们绘制一个网格,显示每个视点的预测图像和预期图像,以及预测掩码和预期掩码。该网格以四行图像为一组,多组依次向下排列成若干大行,如下所示:
<small><center>
┌────────┬────────┐ ┌────────┐
│pred │pred │ │pred │
│image │image │ │image │
│1 │2 │ │n │
├────────┼────────┤ ├────────┤
│expected│expected│ │expected│
│image │image │ ... │image │
│1 │2 │ │n │
├────────┼────────┤ ├────────┤
│pred │pred │ │pred │
│mask │mask │ │mask │
│1 │2 │ │n │
├────────┼────────┤ ├────────┤
│expected│expected│ │expected│
│mask │mask │ │mask │
│1 │2 │ │n │
├────────┼────────┤ ├────────┤
│pred │pred │ │pred │
│image │image │ │image │
│n+1 │n+2 │ │2n │
├────────┼────────┤ ├────────┤
│expected│expected│ │expected│
│image │image │ ... │image │
│n+1 │n+2 │ │2n │
├────────┼────────┤ ├────────┤
│pred │pred │ │pred │
│mask │mask │ │mask │
│n+1 │n+2 │ │2n │
├────────┼────────┤ ├────────┤
│expected│expected│ │expected│
│mask │mask │ │mask │
│n+1 │n+2 │ │2n │
└────────┴────────┘ └────────┘
...
</center></small>
# Tile all results into one picture. The views are split into n_rows chunks;
# each chunk contributes four strips of n_per_row images: predicted images,
# expected images, predicted masks and expected masks.
# Copies are taken so that padding does not modify the original lists.
images_to_display = [images.copy(), expected.copy(), masks.copy(), masks_expected.copy()]
n_rows = 4
n_images = len(images)
blank_image = images[0] * 0
# Ceiling of n_images / n_rows.
n_per_row = 1 + (n_images - 1) // n_rows
# Pad every group with black images so that the last chunk is full.
for _ in range(n_per_row * n_rows - n_images):
    for group in images_to_display:
        group.append(blank_image)

# Wrap each image in its own list: np.block joins the innermost lists along
# the last axis and each enclosing level along the axis before it, so with
# this extra nesting a strip is joined along the width axis and the strips
# are stacked along the height axis.
images_to_display_listed = [[[i] for i in j] for j in images_to_display]
split = []
for row in range(n_rows):
    for group in images_to_display_listed:
        split.append(group[row * n_per_row:(row + 1) * n_per_row])

Image.fromarray(np.block(split))
# Print the maximum channel intensity of the image at index 1 (the second
# rendered view), rescaled from the uint8 range to [0, 1].
print(images[1].max()/255)
# Show the predicted images as a looping animation embedded in the notebook.
# Interactive mode is switched off before the figure is created.
plt.ioff()
fig, ax = plt.subplots(figsize=(3, 3))
ax.grid(None)

# One animation frame per predicted image; each frame is a list of artists.
ims = []
for im in images:
    ims.append([ax.imshow(im, animated=True)])

ani = animation.ArtistAnimation(fig, ims, interval=80, blit=True)
ani_html = ani.to_jshtml()
HTML(ani_html)
# If you want to see the output of the model with the volume forced to opaque white, run this and re-evaluate
# with torch.no_grad():
# gm._implicit_functions[0]._fn.density.fill_(9.0)
# gm._implicit_functions[0]._fn.color.fill_(9.0)