批处理
在深度学习中,每个优化步骤都会对多个输入示例进行操作以实现稳健的训练。因此,高效的批处理至关重要。对于图像输入,批处理非常简单;N 个图像调整为相同的高度和宽度,并作为形状为N x 3 x H x W
的 4 维张量堆叠。对于网格,批处理则不太简单。
网格的批处理模式
假设您想要构建一个包含两个网格的批次,其中mesh1 = (v1: V1 x 3, f1: F1 x 3)
包含V1
个顶点和F1
个面,而mesh2 = (v2: V2 x 3, f2: F2 x 3)
包含V2 (!= V1)
个顶点和F2 (!= F1)
个面。Meshes 数据结构提供了三种不同的方法来批处理异构网格。如果meshes = Meshes(verts = [v1, v2], faces = [f1, f2])
是数据结构的实例化,则
- 列表:将批次中的示例作为张量列表返回。具体来说,
meshes.verts_list()
返回顶点列表[v1, v2]
。类似地,meshes.faces_list()
返回面列表[f1, f2]
。 - 填充:填充表示通过填充额外值来构建张量。具体来说,
meshes.verts_padded()
返回形状为2 x max(V1, V2) x 3
的张量,并用0
填充额外的顶点。类似地,meshes.faces_padded()
返回形状为2 x max(F1, F2) x 3
的张量,并用-1
填充额外的面。 - 打包:打包表示将批次中的示例连接到一个张量中。特别是,
meshes.verts_packed()
返回形状为(V1 + V2) x 3
的张量。类似地,meshes.faces_packed()
返回形状为(F1 + F2) x 3
的张量,用于表示面。在打包模式下,会计算辅助变量,这些变量能够在打包模式、填充模式或列表模式之间进行高效转换。
批处理模式的用例
对不同网格批处理模式的需求是 PyTorch 运算符实现方式所固有的。为了充分利用优化的 PyTorch 操作,Meshes 数据结构允许在不同的批处理模式之间进行高效转换。这在旨在实现快速高效的训练周期时至关重要。一个例子是Mesh R-CNN。在这里,在同一个前向传递中,网络的不同部分假设不同的输入,这些输入是通过在不同的批处理模式之间转换计算出来的。特别是,vert_align 假设一个填充的输入张量,而紧随其后的graph_conv 假设一个打包的输入张量。