【PyTorch】torch.matmul() 函数: 矩阵乘法(矩阵点积)_torch.matmul函数
torch.matmul()
函数详解
torch.matmul()
是 PyTorch 中用于执行 矩阵乘法(矩阵点积) 的函数,支持 1D、2D、3D 及更高维度张量的广义矩阵乘法,是深度学习中非常常用的线性代数运算。
1. 基本语法
torch.matmul(tensor1, tensor2) → Tensor
tensor1
, tensor2
2. 不同维度输入的行为
2.1 两个 1D 张量(向量点积)
a = torch.tensor([1, 2, 3])b = torch.tensor([4, 5, 6])result = torch.matmul(a, b)print(result) # 输出: tensor(32)
计算的是:
1×4+2×5+3×6=321 \\times 4 + 2 \\times 5 + 3 \\times 6 = 321×4+2×5+3×6=32
2.2 1D × 2D 或 2D × 1D(向量和矩阵)
# 1D × 2D,向量乘矩阵a = torch.tensor([1.0, 2.0]) # shape: (2,)b = torch.tensor([[3.0, 4.0], [5.0, 6.0]]) # shape: (2, 2)print(torch.matmul(a, b)) # shape: (2,)# 输出: tensor([13., 16.])
[1,2]×[3456]=[13,16][1, 2] \\times \\begin{bmatrix}3 & 4\\\\5 & 6\\end{bmatrix} = [13, 16][1,2]×[3546]=[13,16]
# 2D × 1D,矩阵乘向量a = torch.tensor([[3.0, 4.0], [5.0, 6.0]]) # shape: (2, 2)b = torch.tensor([1.0, 2.0]) # shape: (2,)print(torch.matmul(a, b)) # shape: (2,)# 输出: tensor([11., 17.])
2.3 两个 2D 张量(标准矩阵乘法)
a = torch.tensor([[1, 2], [3, 4]]) # (2, 2)b = torch.tensor([[5, 6], [7, 8]]) # (2, 2)print(torch.matmul(a, b))# 输出:# tensor([[19, 22],# [43, 50]])
2.4 批量矩阵乘法(3D 或更高)
a = torch.randn(10, 3, 4) # 表示 10 个 (3×4) 的矩阵b = torch.randn(10, 4, 5) # 表示 10 个 (4×5) 的矩阵result = torch.matmul(a, b) # 输出: (10, 3, 5)
广播规则:
对于更高维度的张量,matmul
会按照批次维度进行 广播并执行批矩阵乘法。
3. 与 @
运算符等价
a = torch.randn(2, 3)b = torch.randn(3, 4)# 等价写法result1 = torch.matmul(a, b)result2 = a @ b
4. 与 torch.mm()
、torch.bmm()
区别
torch.matmul()
torch.mm()
torch.bmm()
5. 错误示例
a = torch.randn(3, 4)b = torch.randn(2, 3)torch.matmul(a, b) # ❌ 尺寸不匹配 (4 ≠ 2)
注意:
- 要保证
matmul(a, b)
中a
的最后一维 ==b
的倒数第二维。
6. 常见用途
- 实现线性层(如:
y = x @ W.T + b
) - 注意力机制中的 Q·Kᵀ 操作
- Transformer 中的矩阵操作
- 图神经网络中的邻接矩阵乘法
- batch 处理中的多样本矩阵运算
7. 总结
@
运算符