> 技术文档 > 【PyTorch】torch.multiprocessing 模块:在多个进程间共享和操作张量、模型等资源

【PyTorch】torch.multiprocessing 模块:在多个进程间共享和操作张量、模型等资源

torch.multiprocessing 是 PyTorch 提供的一个模块,用于在多个进程间共享和操作张量、模型等资源,支持并行计算和训练。它是对 Python 标准库 multiprocessing 的扩展,特别优化了与 PyTorch 张量和 CUDA 的集成,适用于多进程任务,如数据加载、模型训练等。


1. 核心功能

  • 多进程管理:创建和管理多个进程,类似于 multiprocessing,但针对 PyTorch 进行了优化。
  • 张量共享:支持在进程间共享 PyTorch 张量,包括 CPU 和 CUDA 张量。
  • 与 CUDA 集成:处理多进程中 CUDA 张量的共享和同步,避免常见问题(如 fork 与 CUDA 的冲突)。
  • 用途:常用于数据并行加载、分布式训练预处理或自定义多进程任务。

2. 关键组件

2.1. 上下文和启动方法

torch.multiprocessing 继承了 Python 的 multiprocessing,支持三种启动方法:

  • spawn (默认):
    • 创建新进程,不继承父进程状态。
    • 推荐用于 CUDA,安全且避免 fork 相关问题。
  • fork
    • 复制父进程状态,可能与 CUDA 冲突,不建议使用。
  • forkserver
    • 使用独立服务器进程启动子进程,较安全但较少用。
  • 设置方法
    import torch.multiprocessing as mpmp.set_start_method(\'spawn\') # 建议在程序开始时设置
2.2. 进程类
  • mp.Process
    • 创建子进程,类似 multiprocessing.Process
    • 参数:target(函数)、args(参数)、kwargs(关键字参数)。
    • 示例:
      def worker(rank): print(f\"Worker {rank} running\")if __name__ == \'__main__\': mp.set_start_method(\'spawn\') processes = [] for i in range(4): p = mp.Process(target=worker, args=(i,)) processes.append(p) p.start() for p in processes: p.join()
2.3. 张量共享
  • 机制:PyTorch 提供特殊方法在进程间共享张量,避免重复复制。
  • CPU 张量
    • 使用 tensor.share_memory_() 标记张量为可共享。
    • 共享内存,进程可直接访问和修改。
  • CUDA 张量
    • spawn 方法下,CUDA 张量不能直接共享,但可通过队列或管道传递。
  • 示例
    import torchimport torch.multiprocessing as mpdef worker(tensor, rank): tensor[rank] = rank # 修改共享张量 print(f\"Worker {rank} set tensor[{rank}] = {rank}\")if __name__ == \'__main__\': mp.set_start_method(\'spawn\') tensor = torch.zeros(4).share_memory_() # 标记为共享 processes = [] for i in range(4): p = mp.Process(target=worker, args=(tensor, i)) processes.append(p) p.start() for p in processes: p.join() print(f\"Final tensor: {tensor}\")
    • 输出:每个进程修改共享张量,最终结果反映所有更改。
2.4. 队列和管道
  • mp.Queue
    • 进程间传递数据(如张量、模型参数)。
    • 适合 CPU 张量,CUDA 张量需小心处理。
  • mp.Pipe
    • 双向通信通道,用于进程间直接交换数据。
  • 示例
    def producer(queue): queue.put(torch.tensor([1, 2, 3]))def consumer(queue): data = queue.get() print(f\"Received: {data}\")if __name__ == \'__main__\': mp.set_start_method(\'spawn\') queue = mp.Queue() p1 = mp.Process(target=producer, args=(queue,)) p2 = mp.Process(target=consumer, args=(queue,)) p1.start(); p2.start() p1.join(); p2.join()

3. 与分布式训练的关系

  • 结合 torch.distributed
    • torch.multiprocessing 常与 torch.distributed 配合,启动多个进程用于分布式数据并行(DDP)。
    • 例如,torchrun 使用 torch.multiprocessing 在后台管理进程。
  • 示例 (DDP):
    import torchimport torch.multiprocessing as mpimport torch.distributed as distimport torch.nn as nnfrom torch.utils.data import DataLoader, DistributedSamplerclass YourModel(nn.Module): def __init__(self): super().__init__() self.fc = nn.Linear(10, 1) def forward(self, x): return self.fc(x)def train(rank, world_size): dist.init_process_group(backend=\'nccl\', rank=rank, world_size=world_size) torch.cuda.set_device(rank) model = YourModel().to(rank) model = nn.parallel.DistributedDataParallel(model, device_ids=[rank]) dataset = YourDataset() sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) loader = DataLoader(dataset, batch_size=32, sampler=sampler) criterion = nn.MSELoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) for data, target in loader: data, target = data.to(rank), target.to(rank) optimizer.zero_grad() output = model(data) loss = criterion(output, target) loss.backward() optimizer.step() dist.destroy_process_group()if __name__ == \'__main__\': mp.set_start_method(\'spawn\') world_size = 4 # 假设 4 个 GPU mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
  • mp.spawn
    • 启动多个进程,自动分配 rank(0 到 nprocs-1)。
    • 参数:fn(目标函数)、args(参数)、nprocs(进程数)。

4. 注意事项

  • 启动方法
    • 始终使用 spawn 作为默认方法,尤其在使用 CUDA 时,避免 fork 导致的内存或设备错误。
    • 在程序开头调用 mp.set_start_method(\'spawn\'),只设置一次。
  • CUDA 限制
    • CUDA 张量不能直接共享内存,需通过队列传递或在每个进程中重新创建。
    • 确保每个进程调用 torch.cuda.set_device 绑定不同 GPU。
  • 保护主模块
    • 代码需放在 if __name__ == \'__main__\': 下,防止子进程重复执行主逻辑。
  • 资源管理
    • 进程结束后调用 join() 确保清理。
    • 避免内存泄漏,谨慎共享大张量。

5. 总结

torch.multiprocessing 是 PyTorch 的多进程工具,扩展了 Python 的 multiprocessing,优化了张量共享和 CUDA 集成。它支持进程创建、数据共享(如通过 share_memory_() 或队列),常用于数据加载或与 torch.distributed 结合进行分布式训练。关键是选择 spawn 启动方法、正确管理设备和资源。

当当礼券