Similarities and differences between register_buffer and torch.nn.Parameter in PyTorch
A quick rundown of how register_buffer and nn.Parameter are alike and how they differ.
Similarities
- Both are recorded in the state_dict, so they are written out when the model is saved.
- Both are bound to the Module, so calls such as cuda() / cpu() / float() migrate them automatically.
- Both become part of the nn.Module and can be accessed as ordinary attributes (self.x).
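To make the shared behaviour concrete, here is a minimal sketch (the Demo module below is made up for illustration): both the parameter and the buffer land in state_dict, follow dtype/device conversions, and are reachable as plain attributes.

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(3))      # parameter
        self.register_buffer("b", torch.zeros(3))  # buffer

m = Demo()
print(list(m.state_dict().keys()))  # ['w', 'b'] -> both are saved with the model
m = m.double()                      # both follow dtype/device conversions
print(m.w.dtype, m.b.dtype)         # torch.float64 torch.float64
print(m.w.shape, m.b.shape)         # both accessible as ordinary attributes
```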
Differences

| | torch.nn.Parameter | register_buffer |
| --- | --- | --- |
| In model.parameters() | Included in model.parameters() | Not included |
| Gradients | requires_grad=True, takes part in backpropagation | requires_grad=False, does not take part in backpropagation |
| How to register | self.w = nn.Parameter(tensor) or self.register_parameter("w", nn.Parameter(...)) | self.register_buffer("buf", tensor) |
| state_dict | Recorded automatically (it shows up in parameters()) | Must be registered with register_buffer(); a plain tensor attribute is not recorded in state_dict |
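The table is easy to verify in a few lines. A small sketch (again with a made-up Demo2 module) shows which of a Parameter, a registered buffer, and a plain tensor attribute end up in parameters(), buffers(), and state_dict:

```python
import torch
import torch.nn as nn

class Demo2(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(3))       # learnable, requires_grad=True
        self.register_buffer("buf", torch.ones(3))  # saved, requires_grad=False
        self.plain = torch.zeros(3)                 # plain attribute: neither saved nor optimized

m = Demo2()
print([n for n, _ in m.named_parameters()])    # ['w']        -> only the Parameter
print([n for n, _ in m.named_buffers()])       # ['buf']      -> only the buffer
print(list(m.state_dict().keys()))             # ['w', 'buf'] -> 'plain' is missing
print(m.w.requires_grad, m.buf.requires_grad)  # True False
```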
Usage recommendations
Use nn.Parameter for tensors that should be learned, i.e. updated by the optimizer. Use register_buffer for tensors that are needed in the forward computation and should be saved with the model and follow device/dtype moves, but must not be updated by the optimizer (running statistics, masks, precomputed constants and the like).
Below is a small test I wrote myself. Run each of ToyModel, ToyModel2 and ToyModel3, save it and load it back, and the behaviour of these two mechanisms should become very clear.
models.py:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyModel(nn.Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        self.a1 = 1  # plain instance attribute, not saved in the ckpt
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.linear(x)
        return out


class ToyModel2(nn.Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        self.a1 = 1  # plain instance attribute, not saved in the ckpt
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()
        # model parameter, requires_grad=True, saved in the ckpt
        self.b1 = nn.Parameter(torch.randn(outChannels))

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.linear(x)
        out += self.b1
        return out


class ToyModel3(nn.Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        self.a1 = 1  # plain instance attribute, not saved in the ckpt
        self.a2 = 2
        self.linear = nn.Linear(inChannels, outChannels)
        self.init_weights()
        self.b1 = nn.Parameter(torch.randn(outChannels))
        # buffer: requires_grad=False, saved in the ckpt; useful for constants
        # needed directly in the computation, accessible as self.c1
        self.register_buffer("c1", torch.ones_like(self.b1), persistent=True)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        out = self.linear(x)
        out += self.b1
        out += self.c1
        return out
```
Save script (writes toymodel3.pth and prints the registered parameters and buffers):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
from pathlib import Path

from models import ToyModel2, ToyModel, ToyModel3

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(lineno)s - %(message)s')

if __name__ == "__main__":
    savePath = Path("toymodel3.pth")
    logger = logging.getLogger(__name__)

    inp = torch.randn(3, 5)
    model = ToyModel3(inp.size(1), inp.size(1) * 2)
    pred = model(inp)
    logger.info(f"{pred.size()=}")

    for m in model.modules():
        logger.info(m)
    for name, param in model.named_parameters():
        logger.info(f"{name = }, {param.size() = }, {param.requires_grad=}")
    for name, buffer in model.named_buffers():
        logger.info(f"{name = }, {buffer.size() = }")

    torch.save(model.state_dict(), savePath)
```
Load script (reads toymodel3.pth back and inspects the parameters and buffers again):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path

from models import ToyModel, ToyModel2, ToyModel3

if __name__ == "__main__":
    savePath = Path("toymodel3.pth")

    inp = torch.randn(3, 5)
    model = ToyModel3(inp.size(1), inp.size(1) * 2)
    ckpt = torch.load(savePath, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt)

    pred = model(inp)
    print(f"{pred.size()=}")
    for m in model.modules():
        print(m)
    for name, param in model.named_parameters():
        print(f"{name = }, {param.size() = }, {param.requires_grad=}")
    for name, buffer in model.named_buffers():
        print(f"{name = }, {buffer.size() = }")
```
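One more follow-up experiment, as a sketch (it assumes toymodel3.pth was produced by the save script above): loading that checkpoint into ToyModel2, which never registers c1, shows that the buffer really was written to the checkpoint, because load_state_dict reports it as an unexpected key.

```python
import torch
from pathlib import Path

from models import ToyModel2

# toymodel3.pth contains linear.weight, linear.bias, b1 and the buffer c1;
# ToyModel2 has no c1, so with strict=False the key is reported as unexpected.
ckpt = torch.load(Path("toymodel3.pth"), map_location="cpu", weights_only=True)
model = ToyModel2(5, 10)
result = model.load_state_dict(ckpt, strict=False)
print(result.unexpected_keys)  # ['c1'] -> the buffer saved by ToyModel3
print(result.missing_keys)     # []
```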


