> 技术文档 > 【PyTorch】torch.matmul() 函数: 矩阵乘法(矩阵点积)_torch.matmul函数

【PyTorch】torch.matmul() 函数: 矩阵乘法(矩阵点积)_torch.matmul函数


torch.matmul() 函数详解

torch.matmul() 是 PyTorch 中用于执行 矩阵乘法(矩阵点积) 的函数,支持 1D、2D、3D 及更高维度张量的广义矩阵乘法,是深度学习中非常常用的线性代数运算。


1. 基本语法

torch.matmul(tensor1, tensor2) → Tensor
参数 说明 tensor1, tensor2 要相乘的两个张量(维度可以是 1D、2D 或更高) 返回值 两个张量的矩阵乘积,遵循广播规则

2. 不同维度输入的行为

2.1 两个 1D 张量(向量点积)

a = torch.tensor([1, 2, 3])b = torch.tensor([4, 5, 6])result = torch.matmul(a, b)print(result) # 输出: tensor(32)

计算的是:
1×4+2×5+3×6=321 \\times 4 + 2 \\times 5 + 3 \\times 6 = 321×4+2×5+3×6=32


2.2 1D × 2D 或 2D × 1D(向量和矩阵)

# 1D × 2D,向量乘矩阵a = torch.tensor([1.0, 2.0]) # shape: (2,)b = torch.tensor([[3.0, 4.0],  [5.0, 6.0]]) # shape: (2, 2)print(torch.matmul(a, b))  # shape: (2,)# 输出: tensor([13., 16.])

[1,2]×[3456]=[13,16][1, 2] \\times \\begin{bmatrix}3 & 4\\\\5 & 6\\end{bmatrix} = [13, 16][1,2]×[3546]=[13,16]

# 2D × 1D,矩阵乘向量a = torch.tensor([[3.0, 4.0],  [5.0, 6.0]]) # shape: (2, 2)b = torch.tensor([1.0, 2.0]) # shape: (2,)print(torch.matmul(a, b))  # shape: (2,)# 输出: tensor([11., 17.])

2.3 两个 2D 张量(标准矩阵乘法)

a = torch.tensor([[1, 2], [3, 4]]) # (2, 2)b = torch.tensor([[5, 6], [7, 8]]) # (2, 2)print(torch.matmul(a, b))# 输出:# tensor([[19, 22],# [43, 50]])

2.4 批量矩阵乘法(3D 或更高)

a = torch.randn(10, 3, 4) # 表示 10 个 (3×4) 的矩阵b = torch.randn(10, 4, 5) # 表示 10 个 (4×5) 的矩阵result = torch.matmul(a, b) # 输出: (10, 3, 5)

广播规则:
对于更高维度的张量,matmul 会按照批次维度进行 广播并执行批矩阵乘法


3. 与 @ 运算符等价

a = torch.randn(2, 3)b = torch.randn(3, 4)# 等价写法result1 = torch.matmul(a, b)result2 = a @ b

4. 与 torch.mm()torch.bmm() 区别

函数 适用维度 描述 torch.matmul() 任意维度 广义矩阵乘法,推荐使用 torch.mm() 仅 2D × 2D 标准矩阵乘法 torch.bmm() 仅 3D × 3D 批矩阵乘法(batch matrix multiplication)

5. 错误示例

a = torch.randn(3, 4)b = torch.randn(2, 3)torch.matmul(a, b) # ❌ 尺寸不匹配 (4 ≠ 2)

注意:

  • 要保证 matmul(a, b)a 的最后一维 == b 的倒数第二维

6. 常见用途

  • 实现线性层(如:y = x @ W.T + b
  • 注意力机制中的 Q·Kᵀ 操作
  • Transformer 中的矩阵操作
  • 图神经网络中的邻接矩阵乘法
  • batch 处理中的多样本矩阵运算

7. 总结

特性 说明 作用 执行广义矩阵乘法 支持维度 1D、2D、3D、甚至更高维 支持广播 是 常用替代 @ 运算符 推荐用途 适用于所有矩阵乘法场景