# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
Implicitron 的组件都基于一个统一的分层配置系统。这允许为每个新组件分别定义可配置变量和所有默认值。然后,与实验相关的所有配置都会自动组合成一个完整指定实验的单个配置文件。一个特别重要的功能是扩展点,用户可以在其中插入 Implicitron 基础组件自己的子类。
定义此系统的文件在此处位于 PyTorch3D 代码库中。Implicitron 体积教程包含一个使用配置系统的简单示例。本教程提供了使用和修改 Implicitron 可配置组件的详细实践经验。
确保已安装 torch
和 torchvision
。如果未安装 pytorch3d
,请使用以下单元格安装它
import os
import sys
import torch
need_pytorch3d=False
try:
import pytorch3d
except ModuleNotFoundError:
need_pytorch3d=True
if need_pytorch3d:
if torch.__version__.startswith("2.2.") and sys.platform.startswith("linux"):
# We try to install PyTorch3D via a released wheel.
pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
version_str="".join([
f"py3{sys.version_info.minor}_cu",
torch.version.cuda.replace(".",""),
f"_pyt{pyt_version_str}"
])
!pip install fvcore iopath
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
else:
# We try to install PyTorch3D from source.
!pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
确保已安装 omegaconf。如果没有,请运行此单元格。(无需重新启动运行时。)
!pip install omegaconf
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.tools.config import (
Configurable,
ReplaceableBase,
expand_args_fields,
get_default_args,
registry,
run_auto_creation,
)
类型提示提供了 Python 中类型的分类。Dataclasses允许您基于具有名称、类型和可能默认值的成员列表创建类。__init__
函数会自动创建,并在存在的情况下作为最后一步调用 __post_init__
函数。例如
@dataclass
class MyDataclass:
a: int
b: int = 8
c: Optional[Tuple[int, ...]] = None
def __post_init__(self):
print(f"created with a = {self.a}")
self.d = 2 * self.b
my_dataclass_instance = MyDataclass(a=18)
assert my_dataclass_instance.d == 16
👷 请注意,此处的 dataclass
装饰器是一个修改类本身定义的函数。它在定义后立即运行。我们的配置系统要求 implicitron 库代码包含其修改版本需要了解用户定义实现的类。因此,我们需要延迟类的修改。我们不使用装饰器。
dc = DictConfig({"a": 2, "b": True, "c": None, "d": "hello"})
assert dc.a == dc["a"] == 2
OmegaConf 具有对 yaml 的序列化和反序列化功能。Hydra 库依赖于此来实现其配置文件。
print(OmegaConf.to_yaml(dc))
assert OmegaConf.create(OmegaConf.to_yaml(dc)) == dc
OmegaConf.structured 从 dataclass 或 dataclass 的实例提供 DictConfig。与普通的 DictConfig 不同,它经过类型检查,并且只能添加已知的键。
structured = OmegaConf.structured(MyDataclass)
assert isinstance(structured, DictConfig)
print(structured)
print()
print(OmegaConf.to_yaml(structured))
structured
知道它缺少 a
的值。
这样的对象具有与 dataclass 兼容的成员,因此可以按如下方式执行初始化。
structured.a = 21
my_dataclass_instance2 = MyDataclass(**structured)
print(my_dataclass_instance2)
您也可以对实例调用 OmegaConf.structured。
structured_from_instance = OmegaConf.structured(my_dataclass_instance)
my_dataclass_instance3 = MyDataclass(**structured_from_instance)
print(my_dataclass_instance3)
我们提供了等效于 OmegaConf.structured
但支持更多功能的函数。要使用我们的函数实现上述功能,请使用以下方法。请注意,我们使用特殊的基类 Configurable
(而不是装饰器)来指示可配置类。
class MyConfigurable(Configurable):
a: int
b: int = 8
c: Optional[Tuple[int, ...]] = None
def __post_init__(self):
print(f"created with a = {self.a}")
self.d = 2 * self.b
# The expand_args_fields function modifies the class like @dataclasses.dataclass.
# If it has not been called on a Configurable object before it has been instantiated, it will
# be called automatically.
expand_args_fields(MyConfigurable)
my_configurable_instance = MyConfigurable(a=18)
assert my_configurable_instance.d == 16
# get_default_args also calls expand_args_fields automatically
our_structured = get_default_args(MyConfigurable)
assert isinstance(our_structured, DictConfig)
print(OmegaConf.to_yaml(our_structured))
our_structured.a = 21
print(MyConfigurable(**our_structured))
我们的系统允许 Configurable 类包含彼此。需要记住的一件事:在 __post_init__
中添加对 run_auto_creation
的调用。
class Inner(Configurable):
a: int = 8
b: bool = True
c: Tuple[int, ...] = (2, 3, 4, 6)
class Outer(Configurable):
inner: Inner
x: str = "hello"
xx: bool = False
def __post_init__(self):
run_auto_creation(self)
outer_dc = get_default_args(Outer)
print(OmegaConf.to_yaml(outer_dc))
outer = Outer(**outer_dc)
assert isinstance(outer, Outer)
assert isinstance(outer.inner, Inner)
print(vars(outer))
print(outer.inner)
请注意,inner_args 是 outer 的一个额外成员。run_auto_creation(self)
等效于
self.inner = Inner(**self.inner_args)
如果一个类使用 ReplaceableBase
作为基类而不是 Configurable
,我们称之为可替换的。它表示它旨在供子类在其位置使用。我们可能会使用 NotImplementedError
来指示子类预期实现的功能。系统维护一个全局 registry
,其中包含每个 ReplaceableBase 的子类。子类使用装饰器在其中注册自身。
包含 ReplaceableBase 的可配置类(即使用我们系统的类,即 Configurable
或 ReplaceableBase
的子类)还必须包含一个对应类型为 str
的 class_type
字段,该字段指示要使用的具体子类。
class InnerBase(ReplaceableBase):
def say_something(self):
raise NotImplementedError
@registry.register
class Inner1(InnerBase):
a: int = 1
b: str = "h"
def say_something(self):
print("hello from an Inner1")
@registry.register
class Inner2(InnerBase):
a: int = 2
def say_something(self):
print("hello from an Inner2")
class Out(Configurable):
inner: InnerBase
inner_class_type: str = "Inner1"
x: int = 19
def __post_init__(self):
run_auto_creation(self)
def talk(self):
self.inner.say_something()
Out_dc = get_default_args(Out)
print(OmegaConf.to_yaml(Out_dc))
Out_dc.inner_class_type = "Inner2"
out = Out(**Out_dc)
print(out.inner)
out.talk()
请注意,在这种情况下,有很多 args
成员。在代码中通常可以忽略它们。它们是配置所需的。
print(vars(out))
class MyLinear(torch.nn.Module, Configurable):
d_in: int = 2
d_out: int = 200
def __post_init__(self):
super().__init__()
self.linear = torch.nn.Linear(in_features=self.d_in, out_features=self.d_out)
def forward(self, x):
return self.linear.forward(x)
my_linear = MyLinear()
input = torch.zeros(2)
output = my_linear(input)
print("output shape:", output.shape)
my_linear
具有 Module 的所有常用功能。例如,它可以使用 torch.save
和 torch.load
保存和加载。它有参数
for name, value in my_linear.named_parameters():
print(name, value.shape)
假设我正在使用一个库,其中 Out
类似于第 5 节,但我希望实现我自己的 InnerBase 子类。我需要做的就是在其定义之前注册它,但我需要在 Out 上显式或隐式调用 expand_args_fields 之前执行此操作。
@registry.register
class UserImplementedInner(InnerBase):
a: int = 200
def say_something(self):
print("hello from the user")
此时,我们需要重新定义类 Out。否则,如果它已经在没有 UserImplementedInner 的情况下展开,那么以下操作将不起作用,因为类的已知实现是在其展开时固定的。
如果您正在从脚本运行实验,那么这里需要记住的是,您必须在使用库类之前导入您自己的模块,这些模块会注册您自己的实现。
class Out(Configurable):
inner: InnerBase
inner_class_type: str = "Inner1"
x: int = 19
def __post_init__(self):
run_auto_creation(self)
def talk(self):
self.inner.say_something()
out2 = Out(inner_class_type="UserImplementedInner")
print(out2.inner)
让我们看看如果我们有一个要使其可插拔的子组件,以允许用户提供自己的子组件,需要发生什么。
class SubComponent(Configurable):
x: float = 0.25
def apply(self, a: float) -> float:
return a + self.x
class LargeComponent(Configurable):
repeats: int = 4
subcomponent: SubComponent
def __post_init__(self):
run_auto_creation(self)
def apply(self, a: float) -> float:
for _ in range(self.repeats):
a = self.subcomponent.apply(a)
return a
large_component = LargeComponent()
assert large_component.apply(3) == 4
print(OmegaConf.to_yaml(LargeComponent))
通用化
class SubComponentBase(ReplaceableBase):
def apply(self, a: float) -> float:
raise NotImplementedError
@registry.register
class SubComponent(SubComponentBase):
x: float = 0.25
def apply(self, a: float) -> float:
return a + self.x
class LargeComponent(Configurable):
repeats: int = 4
subcomponent: SubComponentBase
subcomponent_class_type: str = "SubComponent"
def __post_init__(self):
run_auto_creation(self)
def apply(self, a: float) -> float:
for _ in range(self.repeats):
a = self.subcomponent.apply(a)
return a
large_component = LargeComponent()
assert large_component.apply(3) == 4
print(OmegaConf.to_yaml(LargeComponent))
以下内容发生了变化
@registry.register
装饰,并将基类更改为新的基类。subcomponent_class_type
作为外部类的成员添加。subcomponent_args
必须更改为 subcomponent_SubComponent_args
。__post_init__
或在其中不调用 run_auto_creation
。 subcomponent_class_type = "SubComponent"
而不是 subcomponent_class_type: str = "SubComponent"