> 技术文档 > Pytorch中register_buffer和torch.nn.Parameter的异同

Pytorch中register_buffer和torch.nn.Parameter的异同

说下register_buffer和Parameter的异同

相同点

方面 描述 追踪 都会被加入 state_dict模型保存时会保存下来)。 与 Module 的绑定 都会随着模型移动到 cuda / cpu / float() 等而自动迁移。 都是 nn.Module 的一部分 都可以通过模块属性访问,如 self.x

不同点

方面 torch.nn.Parameter register_buffer 是否是可训练参数 ✅ 是,会被视为模型需要优化的参数(model.parameters() 中包含) ❌ 否,不会被优化器更新 梯度计算 默认 requires_grad=True,参与反向传播 默认 requires_grad=False,不参与反向传播 用途场景 模型的权重、偏置等需要学习的参数 均值、方差、mask、位置编码等常量或状态,如 BatchNorm 中的 running mean/var 注册方式 self.w = nn.Parameter(tensor)self.register_parameter(\"w\", nn.Parameter(...)) self.register_buffer(\"buf\", tensor) 是否显示在 parameters() ✅ 会显示 ❌ 不会显示 是否能直接赋值注册 ✅ 可以直接赋值 ❌ 必须通过 register_buffer() 注册,否则不会记录到 state_dict

使用建议

情境 推荐使用 需要优化 nn.Parameter 只做记录或参与计算但不优化 register_buffer 实现自定义模块(如 BatchNorm)时的状态 register_buffer 使用位置编码、attention mask register_buffer 模型保存中需要但不训练 register_buffer

这里我自己写了一个测试代码,分别运行ToyModel1 2 3 保存并读取,相信会对这两个函数有很深刻的认识。

import torchimport torch.nn as nnimport torch.nn.functional as Fclass ToyModel(nn.Module): def __init__(self, inChannels, outChannels): super().__init__() self.a1 = 1 # 实例成员,不会保存在ckpt中 self.a2 = 2 self.linear = nn.Linear(inChannels, outChannels) self.init_weights() def init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) nn.init.zeros_(m.bias) def forward(self, x): out = self.linear(x) return outclass ToyModel2(nn.Module): def __init__(self, inChannels, outChannels): super().__init__() self.a1 = 1 # 实例成员,不会保存在ckpt中 self.a2 = 2 self.linear = nn.Linear(inChannels, outChannels) self.init_weights() self.b1 = nn.Parameter(torch.randn(outChannels),) # 模型参数,requires_grad=True, 保存进ckpt def init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) nn.init.zeros_(m.bias) def forward(self, x): out = self.linear(x) out += self.b1 return outclass ToyModel3(nn.Module): def __init__(self, inChannels, outChannels): super().__init__() self.a1 = 1 # 实例成员,不会保存在ckpt中 self.a2 = 2 self.linear = nn.Linear(inChannels, outChannels) self.init_weights() self.b1 = nn.Parameter(torch.randn(outChannels),) self.register_buffer(\"c1\", torch.ones_like(self.b1), persistent=True) # 类成员,requires_grad=False, 保存进ckpt,用于保存需要直接计算的常量,可以用self.c1访问 def init_weights(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) nn.init.zeros_(m.bias) def forward(self, x): out = self.linear(x) out += self.b1 out += self.c1 return out
import torchimport torch.nn as nnimport torch.nn.functional as Fimport loggingfrom pathlib import Pathfrom models import ToyModel2, ToyModel, ToyModel3logging.basicConfig(level=logging.INFO,  format=\'%(asctime)s - %(name)s - %(levelname)s - %(lineno)s - %(message)s\')if __name__ == \"__main__\": savePath = Path(\"toymodel3.pth\") logger = logging.getLogger(__name__) inp = torch.randn(3, 5) model = ToyModel3(inp.size(1), inp.size(1) * 2) pred = model(inp) logger.info(f\"{pred.size()=}\") for m in model.modules(): logger.info(m) for name, param in model.named_parameters(): logger.info(f\"{name = }, {param.size() = }, {param.requires_grad=}\") for name, buffer in model.named_buffers(): logger.info(f\"{name = }, {buffer.size() = }\") torch.save(model.state_dict(), savePath)
import torchimport torch.nn as nnimport torch.nn.functional as Ffrom pathlib import Pathfrom models import ToyModel, ToyModel2, ToyModel3if __name__ == \"__main__\": savePath = Path(\"toymodel3.pth\") inp = torch.randn(3, 5) model = ToyModel3(inp.size(1), inp.size(1) * 2) ckpt = torch.load(savePath, map_location=\"cpu\", weights_only=True) model.load_state_dict(ckpt) pred = model(inp) print(f\"{pred.size()=}\") for m in model.modules(): print(m) for name, param in model.named_parameters(): print(f\"{name = }, {param.size() = }, {param.requires_grad=}\") for name, buffer in model.named_buffers(): print(f\"{name = }, {buffer.size() = }\")