
Triton Study Notes 3: Matrix Tiling Techniques


Puzzle 10: Two Dimensional Convolution

A batched 2D convolution.

Uses one program id axis. Block size B0 represents the batches to process out of N0.
Image x is of size H by W with only 1 channel, and kernel k is of size KH by KW.

.. math::
    z_{i, j, l} = \sum_{oj, ol}^{j+oj \le H,\ l+ol \le W} k_{oj, ol} \times x_{i, j+oj, l+ol}
    \text{ for } i = 1 \ldots N_0,\ j = 1 \ldots H,\ l = 1 \ldots W
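The bounds j+oj ≤ H and l+ol ≤ W mean that windows hanging off the bottom or right edge simply drop the out-of-range terms, which is equivalent to zero-padding the image on those two sides. A minimal eager-mode sketch of that boundary rule for a single output element (the 3x3 image and 2x2 kernel here are illustrative, not part of the puzzle code):

import torch

# Hypothetical small example: one 3x3 image, one 2x2 kernel.
x = torch.arange(1.0, 10.0).reshape(3, 3)      # [[1,2,3],[4,5,6],[7,8,9]]
k = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

H, W = x.shape
KH, KW = k.shape

def out_elem(j, l):
    """z[j, l] = sum over the KH x KW window, skipping out-of-range taps."""
    acc = 0.0
    for oj in range(KH):
        for ol in range(KW):
            if j + oj < H and l + ol < W:   # same condition as the mask in the kernel below
                acc += k[oj, ol] * x[j + oj, l + ol]
    return acc

print(out_elem(0, 0))  # 1*1 + 5*1 = 6
print(out_elem(2, 2))  # only x[2,2] is in range -> 9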

def conv2d_spec(x: Float32[4, 8, 8], k: Float32[4, 4]) -> Float32[4, 8, 8]:
    z = torch.zeros(4, 8, 8)
    # Zero-pad the bottom and right edges so every 4x4 window stays in bounds.
    x = torch.nn.functional.pad(x, (0, 4, 0, 4, 0, 0), value=0.0)
    for i in range(8):
        for j in range(8):
            z[:, i, j] = (k[None, :, :] * x[:, i : i + 4, j : j + 4]).sum(1).sum(1)
    return z
@triton.jit
def conv2d_kernel(
    x_ptr, k_ptr, z_ptr, N0, H, W,
    KH: tl.constexpr, KW: tl.constexpr, B0: tl.constexpr,
):
    """2D convolution kernel implemented with @triton.jit.

    Args:
        x_ptr: pointer to the input tensor
        k_ptr: pointer to the convolution kernel
        z_ptr: pointer to the output tensor
        N0: batch size
        H: input height
        W: input width
        KH: kernel height (compile-time constant)
        KW: kernel width (compile-time constant)
        B0: block size (compile-time constant)

    Performs a 2D convolution on the input tensor and stores the result in
    the output tensor, processing B0 batch elements per program.
    """
    pid_0 = tl.program_id(0)
    off_i = pid_0 * B0 + tl.arange(0, B0)
    mask_i = off_i < N0

    # Load the full KH x KW kernel once.
    off_h = tl.arange(0, KH)
    off_w = tl.arange(0, KW)
    off_hw = off_h[:, None] * KW + off_w[None, :]
    conv_kernel = tl.load(k_ptr + off_hw)

    for j in tl.range(0, H):
        for l in tl.range(0, W):
            # Window offsets for output position (j, l); out-of-range taps are masked to zero.
            off_j_oj = j + off_h[None, :, None]
            off_l_ol = l + off_w[None, None, :]
            off_x = off_i[:, None, None] * H * W + off_j_oj * W + off_l_ol
            mask_x = mask_i[:, None, None] & (off_j_oj < H) & (off_l_ol < W)
            x = tl.load(x_ptr + off_x, mask=mask_x, other=0.0)
            # Reduce over the kernel window, keeping the batch dimension.
            z = tl.sum(tl.sum(x * conv_kernel[None, :, :], axis=2), axis=1)
            off_z = off_i * H * W + j * W + l
            tl.store(z_ptr + off_z, z, mask=mask_i)
    return
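A minimal launch sketch, assuming conv2d_kernel and conv2d_spec from above are in scope and a CUDA device is available; the grid covers the batch dimension in blocks of B0:

import torch
import triton

N0, H, W, KH, KW, B0 = 4, 8, 8, 4, 4, 1
x = torch.randn(N0, H, W, device="cuda")
k = torch.randn(KH, KW, device="cuda")
z = torch.empty(N0, H, W, device="cuda")

# One program per block of B0 batch elements.
grid = lambda meta: (triton.cdiv(N0, meta["B0"]),)
conv2d_kernel[grid](x, k, z, N0, H, W, KH, KW, B0=B0)

# The result should match the eager reference (spec runs on CPU).
print(torch.allclose(z.cpu(), conv2d_spec(x.cpu(), k.cpu()), atol=1e-4))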

unittest

  • Test basic convolution operation with small input tensor and kernel

    import unittest

    import torch
    import triton

    from puzzles import conv2d_kernel


    class PuzzlesTest(unittest.TestCase):
        def test_basic_functionality_with_small_inputs(self):
            """Test basic convolution operation with small input tensor and kernel"""
            # Input parameters
            N0, H, W, KH, KW, B0 = 1, 3, 3, 2, 2, 1

            # Create input tensor and kernel with simple values
            x = torch.tensor(
                [[1.0, 2.0, 3.0],
                 [4.0, 5.0, 6.0],
                 [7.0, 8.0, 9.0]], device="cuda"
            ).reshape(1, H, W)
            k = torch.tensor([[1.0, 0.0],
                              [0.0, 1.0]], device="cuda")

            # Expected output (manually computed; the kernel zero-pads the
            # bottom/right edges, so the output keeps the H x W shape).
            expected_z = torch.tensor(
                [[1 + 5, 2 + 6, 3 + 0],
                 [4 + 8, 5 + 9, 6 + 0],
                 [7 + 0, 8 + 0, 9 + 0]], dtype=torch.float32, device="cuda"
            ).reshape(1, H, W)

            # Allocate output tensor (same spatial size as the input)
            z = torch.empty((N0, H, W), device="cuda")

            # Define grid and run the kernel (tensors are passed directly;
            # Triton takes their data pointers)
            grid = lambda meta: (triton.cdiv(N0, meta["B0"]),)
            conv2d_kernel[grid](x, k, z, N0, H, W, KH, KW, B0=B0)

            # Check if results match
            self.assertTrue(
                torch.allclose(z, expected_z, rtol=1e-3, atol=1e-3),
                f"Expected:\n{expected_z}\nGot:\n{z}",
            )


    if __name__ == "__main__":
        unittest.main()
  • Test when N0 is exactly divisible by B0

    import unittest

    import torch
    import triton

    from puzzles import conv2d_kernel


    class PuzzlesTest(unittest.TestCase):
        def test_conv2d_kernel_full_block_processing(self):
            """Test conv2d_kernel when N0 is exactly divisible by B0"""
            # Input parameters
            N0, H, W, KH, KW, B0 = 32, 5, 5, 3, 3, 32

            # Create random input tensors
            x = torch.randn(N0, H, W, device="cuda")
            k = torch.randn(KH, KW, device="cuda")
            z = torch.empty(N0, H, W, device="cuda")

            # Define grid function and launch the kernel (tensors passed directly)
            grid = lambda meta: (triton.cdiv(N0, meta["B0"]),)
            conv2d_kernel[grid](x, k, z, N0, H, W, KH, KW, B0=B0)

            # Compute expected output using PyTorch's conv2d.
            # The kernel only zero-pads the bottom/right edges, so pad the input
            # there by (KH - 1, KW - 1) and run an unpadded cross-correlation.
            x_4d = torch.nn.functional.pad(x.view(N0, 1, H, W), (0, KW - 1, 0, KH - 1))
            k_4d = k.view(1, 1, KH, KW)
            expected_z = torch.nn.functional.conv2d(x_4d, k_4d).squeeze(1)

            # Check if results match
            self.assertTrue(
                torch.allclose(z, expected_z, rtol=1e-3, atol=1e-3),
                "Output does not match expected result",
            )


    if __name__ == "__main__":
        unittest.main()
  • Test with random input values

    import unittest

    import torch
    import triton

    from puzzles import conv2d_kernel


    class PuzzlesTest(unittest.TestCase):
        def test_conv2d_kernel_random_values(self):
            """Test conv2d_kernel with random input values"""
            # Setup test parameters
            N0, H, W, KH, KW, B0 = 4, 10, 10, 4, 4, 2

            # Generate random inputs
            torch.manual_seed(0)
            x = torch.rand((N0, H, W), device="cuda") - 0.5
            k = torch.rand((KH, KW), device="cuda") - 0.5

            # Allocate output tensor
            z = torch.empty((N0, H, W), device="cuda")

            # Compute reference result using PyTorch's conv2d.
            # Pad only the bottom/right edges to match the kernel's boundary handling.
            x_4d = torch.nn.functional.pad(x.unsqueeze(1), (0, KW - 1, 0, KH - 1))
            k_4d = k.unsqueeze(0).unsqueeze(0)  # add out_channels and in_channels dims
            z_ref = torch.nn.functional.conv2d(x_4d, k_4d).squeeze(1)

            # Run the kernel (tensors passed directly)
            grid = lambda meta: (triton.cdiv(N0, meta["B0"]),)
            conv2d_kernel[grid](x, k, z, N0, H, W, KH, KW, B0=B0)

            # Verify results
            self.assertTrue(
                torch.allclose(z, z_ref, rtol=1e-3, atol=1e-3),
                "Convolution results do not match reference",
            )


    if __name__ == "__main__":
        unittest.main()


Puzzle 11: Matrix Multiplication

A blocked matrix multiplication.

Uses three program id axes. Block size B2 represents the batches to process out of N2.
Block size B0 represents the rows of x to process out of N0. Block size B1 represents the columns
of y to process out of N1. The middle (reduction) dimension is MID.

.. math::
    z_{i, j, k} = \sum_{l} x_{i, j, l} \times y_{i, l, k}
    \text{ for } i = 1 \ldots N_2,\ j = 1 \ldots N_0,\ k = 1 \ldots N_1

You are allowed to use tl.dot, which computes a smaller matmul.

Hint: the main trick is that you can split a matmul into smaller parts.

.. math::
    z_{i, j, k} = \sum_{l=1}^{L/2} x_{i, j, l} \times y_{i, l, k} + \sum_{l=L/2+1}^{L} x_{i, j, l} \times y_{i, l, k}
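A quick eager-mode check of this identity, as a hedged sketch with arbitrary shapes: accumulating partial products over chunks of the reduction dimension gives the same result as one full matmul, which is exactly what the B_MID loop in the kernel below does.

import torch

torch.manual_seed(0)
x = torch.randn(4, 32, 64)   # (batch, N0, MID)
y = torch.randn(4, 64, 32)   # (batch, MID, N1)

MID, B_MID = x.shape[-1], 16
z = torch.zeros(4, 32, 32)
for l in range(0, MID, B_MID):
    # Partial product over one chunk of the reduction dimension.
    z += x[:, :, l : l + B_MID] @ y[:, l : l + B_MID, :]

print(torch.allclose(z, x @ y, atol=1e-5))  # True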

def dot_spec(x: Float32[4, 32, 32], y: Float32[4, 32, 32]) -> Float32[4, 32, 32]:
    return x @ y
@triton.jit
def dot_kernel(
    x_ptr, y_ptr, z_ptr, N0, N1, N2, MID,
    B0: tl.constexpr, B1: tl.constexpr, B2: tl.constexpr, B_MID: tl.constexpr,
):
    block_id_j = tl.program_id(0)
    block_id_k = tl.program_id(1)
    block_id_i = tl.program_id(2)

    off_i = block_id_i * B2 + tl.arange(0, B2)  # batch
    off_j = block_id_j * B0 + tl.arange(0, B0)  # rows of x
    off_k = block_id_k * B1 + tl.arange(0, B1)  # cols of y
    mask_i = off_i < N2
    mask_j = off_j < N0
    mask_k = off_k < N1

    # Accumulator and output offsets for this (batch, row, col) tile.
    z = tl.zeros((B2, B0, B1), dtype=tl.float32)
    off_z = off_i[:, None, None] * N0 * N1 + off_j[None, :, None] * N1 + off_k[None, None, :]
    mask_z = mask_i[:, None, None] & mask_j[None, :, None] & mask_k[None, None, :]

    # Walk the reduction dimension in chunks of B_MID and accumulate partial dots.
    for i in range(0, MID, B_MID):
        off_b = i + tl.arange(0, B_MID)
        mask_b = off_b < MID
        off_x = off_i[:, None, None] * N0 * MID + off_j[None, :, None] * MID + off_b[None, None, :]
        off_y = off_i[:, None, None] * MID * N1 + off_b[None, :, None] * N1 + off_k[None, None, :]
        mask_x = mask_i[:, None, None] & mask_j[None, :, None] & mask_b[None, None, :]
        mask_y = mask_i[:, None, None] & mask_b[None, :, None] & mask_k[None, None, :]
        x = tl.load(x_ptr + off_x, mask=mask_x, other=0.0)
        y = tl.load(y_ptr + off_y, mask=mask_y, other=0.0)
        z += tl.dot(x, y, allow_tf32=False)

    tl.store(z_ptr + off_z, z, mask=mask_z)
    return
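A minimal launch sketch, assuming dot_kernel from above is in scope; the block sizes are illustrative and respect tl.dot's minimum tile size of 16 along each matmul dimension:

import torch
import triton

N2, N0, N1, MID = 4, 32, 32, 32
B2, B0, B1, B_MID = 1, 16, 16, 16

x = torch.randn(N2, N0, MID, device="cuda")
y = torch.randn(N2, MID, N1, device="cuda")
z = torch.empty(N2, N0, N1, device="cuda")

# One program per (row tile, col tile, batch tile).
grid = lambda meta: (
    triton.cdiv(N0, meta["B0"]),
    triton.cdiv(N1, meta["B1"]),
    triton.cdiv(N2, meta["B2"]),
)
dot_kernel[grid](x, y, z, N0, N1, N2, MID, B0=B0, B1=B1, B2=B2, B_MID=B_MID)

print(torch.allclose(z, x @ y, atol=1e-3))  # expected True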

Puzzle 12: Quantized Matrix Mult

When doing matrix multiplication with quantized neural networks, a common strategy is to store the weight matrix in lower precision, with a shift and scale term.

For this problem, our weights will be stored in 4 bits. We can pack FPINT of these into one 32-bit integer. In addition, for every group of GROUP weights in order, we store 1 float scale value and 1 4-bit shift value. We store these for the column of weight. The activations are stored separately in standard floats.

Mathematically, it looks like:

.. math::
    z_{j, k} = \sum_{l} sc_{j, \frac{l}{g}} \, (w_{j, l} - sh_{j, \frac{l}{g}}) \times y_{l, k}
    \text{ for } j = 1 \ldots N_0,\ k = 1 \ldots N_1

Where g is the group size (GROUP).

However, it is a bit more complex, since we first need to extract the 4-bit values into floats.

Note:

  • We don't consider the batch dimension, i.e. i, in this puzzle.
  • Remember to unpack the FPINT values into separate 4-bit values; this involves some shape manipulation (a small unpacking sketch follows below).
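The bit manipulation is easier to see in eager PyTorch first. This is a hedged sketch of the unpacking step only (the name packed is illustrative): shifting a packed int32 right by 0, 4, 8, ... bits and masking with 0xF recovers the eight 4-bit values, which is what both quant_dot_spec's extract helper and the kernel below do.

import torch

# Hypothetical packed value: eight 4-bit numbers 1..8, lowest nibble first.
vals = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.int32)
packed = torch.zeros((), dtype=torch.int32)
for i, v in enumerate(vals):
    packed |= v << (4 * i)

# Unpack: shift by 0, 4, ..., 28 bits and keep the low 4 bits of each result.
shifts = torch.arange(8) * 4
unpacked = (packed >> shifts) & (2**4 - 1)
print(unpacked)  # recovers 1..8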
FPINT = 32 // 4
GROUP = 8


def quant_dot_spec(
    scale: Float32[32, 8],
    offset: Int32[32,],
    weight: Int32[32, 8],
    activation: Float32[64, 32],
) -> Float32[32, 32]:
    offset = offset.view(32, 1)

    def extract(x):
        # Split each packed int32 into eight 4-bit values.
        over = torch.arange(8) * 4
        mask = 2**4 - 1
        return (x[..., None] >> over) & mask

    scale = scale[..., None].expand(-1, 8, GROUP).contiguous().view(-1, 64)
    offset = (
        extract(offset)[..., None].expand(-1, 1, 8, GROUP).contiguous().view(-1, 64)
    )
    return (scale * (extract(weight).view(-1, 64) - offset)) @ activation


@triton.jit
def quant_dot_kernel(
    scale_ptr, offset_ptr, weight_ptr, activation_ptr, z_ptr,
    N0, N1, MID,
    B0: tl.constexpr, B1: tl.constexpr, B_MID: tl.constexpr,
):
    block_id_j = tl.program_id(0)
    block_id_k = tl.program_id(1)
    off_j = block_id_j * B0 + tl.arange(0, B0)
    off_k = block_id_k * B1 + tl.arange(0, B1)
    mask_j = off_j < N0
    mask_k = off_k < N1

    z = tl.zeros((B0, B1), dtype=tl.float32)
    off_z = off_j[:, None] * N1 + off_k[None, :]
    mask_z = mask_j[:, None] & mask_k[None, :]

    for l in tl.range(0, MID, B_MID):
        # Load scale: one float per GROUP weights.
        off_l_div_g = tl.arange(0, B_MID // GROUP) + (l // GROUP)
        mask_l_div_g = off_l_div_g < (MID // GROUP)
        off_scale = off_j[:, None] * (MID // GROUP) + off_l_div_g[None, :]
        mask_scale = mask_j[:, None] & mask_l_div_g[None, :]
        scale = tl.load(scale_ptr + off_scale, mask=mask_scale)

        # Load shift (offset): each 32-bit integer packs FPINT (8) 4-bit shifts.
        shift = tl.load(offset_ptr + off_j, mask=mask_j)

        # Load weight: stored as packed 4-bit values.
        off_weight_l = l // FPINT + tl.arange(0, B_MID // FPINT)
        mask_weight_l = off_weight_l < (MID // FPINT)
        off_weight = off_j[:, None] * (MID // FPINT) + off_weight_l[None, :]
        mask_weight = mask_j[:, None] & mask_weight_l[None, :]
        weight = tl.load(weight_ptr + off_weight, mask=mask_weight)

        # Load activation as normal floats.
        off_l = l + tl.arange(0, B_MID)
        mask_l = off_l < MID
        off_activation = off_l[:, None] * N1 + off_k[None, :]
        mask_activation = mask_l[:, None] & mask_k[None, :]
        activation = tl.load(activation_ptr + off_activation, mask=mask_activation)

        # Unpack weight and shift.
        BITS = 32 // FPINT
        unpack_offs = tl.arange(0, FPINT) * BITS
        unpack_upperbound_mask = (1 << BITS) - 1
        unpacked_shift = (shift[:, None] >> unpack_offs) & unpack_upperbound_mask
        unpacked_weight = (weight[:, :, None] >> unpack_offs) & unpack_upperbound_mask

        # Dequantize: [BLOCK_J, 8, 1] * ([BLOCK_J, 8, 8] - [BLOCK_J, 8, 1])
        transformed_weight = scale[:, :, None] * (
            unpacked_weight - unpacked_shift[:, :, None]
        )
        # Flatten back to shape [*, 64].
        transformed_weight = transformed_weight.reshape(
            unpacked_shift.shape[0], unpacked_shift.shape[-1] * FPINT
        )

        # Accumulate the partial matmul for this chunk.
        z += tl.dot(transformed_weight, activation)

    tl.store(z_ptr + off_z, z, mask=mask_z)
    return
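A minimal launch sketch, assuming quant_dot_kernel, quant_dot_spec, FPINT, and GROUP from above are in scope; the shapes match the spec (32x64 weights packed into 32x8 int32 values, MID = 64), and B_MID = MID so each program dequantizes its full row block in one pass:

import torch
import triton

N0, N1, MID = 32, 32, 64
B0, B1, B_MID = 16, 16, 64

torch.manual_seed(0)
scale = torch.rand(N0, MID // GROUP, device="cuda")
offset = torch.randint(0, 2**31 - 1, (N0,), dtype=torch.int32, device="cuda")
weight = torch.randint(0, 2**31 - 1, (N0, MID // FPINT), dtype=torch.int32, device="cuda")
activation = torch.randn(MID, N1, device="cuda")
z = torch.empty(N0, N1, device="cuda")

grid = lambda meta: (triton.cdiv(N0, meta["B0"]), triton.cdiv(N1, meta["B1"]))
quant_dot_kernel[grid](scale, offset, weight, activation, z,
                       N0, N1, MID, B0=B0, B1=B1, B_MID=B_MID)

# Loose tolerance because tl.dot may use TF32 on recent GPUs; spec runs on CPU.
ref = quant_dot_spec(scale.cpu(), offset.cpu(), weight.cpu(), activation.cpu())
print(torch.allclose(z.cpu(), ref, rtol=1e-2, atol=1e-2))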

Reference

  1. https://zhuanlan.zhihu.com/p/672086654

  2. https://gitcode.com/gh_mirrors/tr/Triton-Puzzles-Lite