【PyTorch】torch.multiprocessing 模块:在多个进程间共享和操作张量、模型等资源
torch.multiprocessing
是 PyTorch 提供的一个模块,用于在多个进程间共享和操作张量、模型等资源,支持并行计算和训练。它是对 Python 标准库 multiprocessing
的扩展,特别优化了与 PyTorch 张量和 CUDA 的集成,适用于多进程任务,如数据加载、模型训练等。
1. 核心功能
- 多进程管理:创建和管理多个进程,类似于
multiprocessing
,但针对 PyTorch 进行了优化。 - 张量共享:支持在进程间共享 PyTorch 张量,包括 CPU 和 CUDA 张量。
- 与 CUDA 集成:处理多进程中 CUDA 张量的共享和同步,避免常见问题(如 fork 与 CUDA 的冲突)。
- 用途:常用于数据并行加载、分布式训练预处理或自定义多进程任务。
2. 关键组件
2.1. 上下文和启动方法
torch.multiprocessing
继承了 Python 的 multiprocessing
,支持三种启动方法:
- spawn (默认):
- 创建新进程,不继承父进程状态。
- 推荐用于 CUDA,安全且避免 fork 相关问题。
- fork:
- 复制父进程状态,可能与 CUDA 冲突,不建议使用。
- forkserver:
- 使用独立服务器进程启动子进程,较安全但较少用。
- 设置方法:
import torch.multiprocessing as mpmp.set_start_method(\'spawn\') # 建议在程序开始时设置
2.2. 进程类
- mp.Process:
- 创建子进程,类似
multiprocessing.Process
。 - 参数:
target
(函数)、args
(参数)、kwargs
(关键字参数)。 - 示例:
def worker(rank): print(f\"Worker {rank} running\")if __name__ == \'__main__\': mp.set_start_method(\'spawn\') processes = [] for i in range(4): p = mp.Process(target=worker, args=(i,)) processes.append(p) p.start() for p in processes: p.join()
- 创建子进程,类似
2.3. 张量共享
- 机制:PyTorch 提供特殊方法在进程间共享张量,避免重复复制。
- CPU 张量:
- 使用
tensor.share_memory_()
标记张量为可共享。 - 共享内存,进程可直接访问和修改。
- 使用
- CUDA 张量:
spawn
方法下,CUDA 张量不能直接共享,但可通过队列或管道传递。
- 示例:
import torchimport torch.multiprocessing as mpdef worker(tensor, rank): tensor[rank] = rank # 修改共享张量 print(f\"Worker {rank} set tensor[{rank}] = {rank}\")if __name__ == \'__main__\': mp.set_start_method(\'spawn\') tensor = torch.zeros(4).share_memory_() # 标记为共享 processes = [] for i in range(4): p = mp.Process(target=worker, args=(tensor, i)) processes.append(p) p.start() for p in processes: p.join() print(f\"Final tensor: {tensor}\")
- 输出:每个进程修改共享张量,最终结果反映所有更改。
2.4. 队列和管道
- mp.Queue:
- 进程间传递数据(如张量、模型参数)。
- 适合 CPU 张量,CUDA 张量需小心处理。
- mp.Pipe:
- 双向通信通道,用于进程间直接交换数据。
- 示例:
def producer(queue): queue.put(torch.tensor([1, 2, 3]))def consumer(queue): data = queue.get() print(f\"Received: {data}\")if __name__ == \'__main__\': mp.set_start_method(\'spawn\') queue = mp.Queue() p1 = mp.Process(target=producer, args=(queue,)) p2 = mp.Process(target=consumer, args=(queue,)) p1.start(); p2.start() p1.join(); p2.join()
3. 与分布式训练的关系
- 结合
torch.distributed
:torch.multiprocessing
常与torch.distributed
配合,启动多个进程用于分布式数据并行(DDP)。- 例如,
torchrun
使用torch.multiprocessing
在后台管理进程。
- 示例 (DDP):
import torchimport torch.multiprocessing as mpimport torch.distributed as distimport torch.nn as nnfrom torch.utils.data import DataLoader, DistributedSamplerclass YourModel(nn.Module): def __init__(self): super().__init__() self.fc = nn.Linear(10, 1) def forward(self, x): return self.fc(x)def train(rank, world_size): dist.init_process_group(backend=\'nccl\', rank=rank, world_size=world_size) torch.cuda.set_device(rank) model = YourModel().to(rank) model = nn.parallel.DistributedDataParallel(model, device_ids=[rank]) dataset = YourDataset() sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) loader = DataLoader(dataset, batch_size=32, sampler=sampler) criterion = nn.MSELoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) for data, target in loader: data, target = data.to(rank), target.to(rank) optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() dist.destroy_process_group()if __name__ == \'__main__\': mp.set_start_method(\'spawn\') world_size = 4 # 假设 4 个 GPU mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
- mp.spawn:
- 启动多个进程,自动分配
rank
(0 到 nprocs-1)。 - 参数:
fn
(目标函数)、args
(参数)、nprocs
(进程数)。
- 启动多个进程,自动分配
4. 注意事项
- 启动方法:
- 始终使用
spawn
作为默认方法,尤其在使用 CUDA 时,避免 fork 导致的内存或设备错误。 - 在程序开头调用
mp.set_start_method(\'spawn\')
,只设置一次。
- 始终使用
- CUDA 限制:
- CUDA 张量不能直接共享内存,需通过队列传递或在每个进程中重新创建。
- 确保每个进程调用
torch.cuda.set_device
绑定不同 GPU。
- 保护主模块:
- 代码需放在
if __name__ == \'__main__\':
下,防止子进程重复执行主逻辑。
- 代码需放在
- 资源管理:
- 进程结束后调用
join()
确保清理。 - 避免内存泄漏,谨慎共享大张量。
- 进程结束后调用
5. 总结
torch.multiprocessing
是 PyTorch 的多进程工具,扩展了 Python 的 multiprocessing
,优化了张量共享和 CUDA 集成。它支持进程创建、数据共享(如通过 share_memory_()
或队列),常用于数据加载或与 torch.distributed
结合进行分布式训练。关键是选择 spawn
启动方法、正确管理设备和资源。