Zhao Dongyu's Blog

A life which is unexamined is not worth living.


FlashMLA Source Code Analysis

DeepSeek open-sourced FlashMLA today. I had already read a bit about MLA, and this looked like a great opportunity to learn CUDA acceleration hands-on, so here is a study log of my experiments.

0. Preparation

0.1 Test platform

As the FlashMLA README says:

FlashMLA is an efficient MLA decoding kernel for Hopper GPUs, optimized for variable-length sequences serving.

So a Hopper-architecture GPU is chosen as the test platform.

Test platform info
Mon Feb 24 13:56:23 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.216.03 Driver Version: 535.216.03 CUDA Version: 12.4 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA H20 On | 00000000:16:00.0 Off | 0 |
| N/A 27C P0 74W / 500W | 0MiB / 97871MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA H20 On | 00000000:17:00.0 Off | 0 |
| N/A 29C P0 75W / 500W | 0MiB / 97871MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA H20 On | 00000000:40:00.0 Off | 0 |
| N/A 28C P0 74W / 500W | 0MiB / 97871MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA H20 On | 00000000:41:00.0 Off | 0 |
| N/A 28C P0 74W / 500W | 0MiB / 97871MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 4 NVIDIA H20 On | 00000000:96:00.0 Off | 0 |
| N/A 27C P0 73W / 500W | 0MiB / 97871MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 5 NVIDIA H20 On | 00000000:97:00.0 Off | 0 |
| N/A 28C P0 72W / 500W | 0MiB / 97871MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 6 NVIDIA H20 On | 00000000:C0:00.0 Off | 0 |
| N/A 26C P0 72W / 500W | 0MiB / 97871MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 7 NVIDIA H20 On | 00000000:C1:00.0 Off | 0 |
| N/A 29C P0 73W / 500W | 0MiB / 97871MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| No running processes found |
+---------------------------------------------------------------------------------------+

0.2 Installation

git clone https://github.com/deepseek-ai/FlashMLA.git

git submodule update --init csrc/cutlass/

python setup.py install

Installation log
# python setup.py install
running install
/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/cmd.py:79: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

********************************************************************************
Please avoid running ``setup.py`` directly.
Instead, use pypa/build, pypa/installer or other
standards-based tools.

See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
********************************************************************************

!!
self.initialize_options()
/usr/local/lib/python3.10/dist-packages/setuptools/_distutils/cmd.py:79: EasyInstallDeprecationWarning: easy_install command is deprecated.
!!

********************************************************************************
Please avoid running ``setup.py`` and ``easy_install``.
Instead, use pypa/build, pypa/installer or other
standards-based tools.

See https://github.com/pypa/setuptools/issues/917 for details.
********************************************************************************

!!
self.initialize_options()
running bdist_egg
running egg_info
creating flash_mla.egg-info
writing flash_mla.egg-info/PKG-INFO
writing dependency_links to flash_mla.egg-info/dependency_links.txt
writing top-level names to flash_mla.egg-info/top_level.txt
writing manifest file 'flash_mla.egg-info/SOURCES.txt'
adding license file 'LICENSE'
writing manifest file 'flash_mla.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build/lib.linux-x86_64-cpython-310/flash_mla
copying flash_mla/__init__.py -> build/lib.linux-x86_64-cpython-310/flash_mla
copying flash_mla/flash_mla_interface.py -> build/lib.linux-x86_64-cpython-310/flash_mla
running build_ext
/usr/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py:426: UserWarning: There are no x86_64-linux-gnu-g++ version bounds defined for CUDA version 12.4
warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}')
building 'flash_mla_cuda' extension
creating /data/private/zhaodongyu/FlashMLA/build/temp.linux-x86_64-cpython-310/csrc
Emitting ninja build file /data/private/zhaodongyu/FlashMLA/build/temp.linux-x86_64-cpython-310/build.ninja...
Compiling objects...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/2] c++ -MMD -MF /data/private/zhaodongyu/FlashMLA/build/temp.linux-x86_64-cpython-310/csrc/flash_api.o.d -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/data/private/zhaodongyu/FlashMLA/csrc -I/data/private/zhaodongyu/FlashMLA/csrc/cutlass/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/cuda/include -I/usr/include/python3.10 -c -c /data/private/zhaodongyu/FlashMLA/csrc/flash_api.cpp -o /data/private/zhaodongyu/FlashMLA/build/temp.linux-x86_64-cpython-310/csrc/flash_api.o -O3 -std=c++17 -DNDEBUG -Wno-deprecated-declarations -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flash_mla_cuda -D_GLIBCXX_USE_CXX11_ABI=0
[2/2] /usr/local/cuda/bin/nvcc --generate-dependencies-with-compile --dependency-output /data/private/zhaodongyu/FlashMLA/build/temp.linux-x86_64-cpython-310/csrc/flash_fwd_mla_bf16_sm90.o.d -I/data/private/zhaodongyu/FlashMLA/csrc -I/data/private/zhaodongyu/FlashMLA/csrc/cutlass/include -I/usr/local/lib/python3.10/dist-packages/torch/include -I/usr/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.10/dist-packages/torch/include/TH -I/usr/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/local/cuda/include -I/usr/include/python3.10 -c -c /data/private/zhaodongyu/FlashMLA/csrc/flash_fwd_mla_bf16_sm90.cu -o /data/private/zhaodongyu/FlashMLA/build/temp.linux-x86_64-cpython-310/csrc/flash_fwd_mla_bf16_sm90.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -O3 -std=c++17 -DNDEBUG -D_USE_MATH_DEFINES -Wno-deprecated-declarations -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ -U__CUDA_NO_HALF2_OPERATORS__ -U__CUDA_NO_BFLOAT16_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math --ptxas-options=-v,--register-usage-level=10 -gencode arch=compute_90a,code=sm_90a --threads 32 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=flash_mla_cuda -D_GLIBCXX_USE_CXX11_ABI=0
ptxas info : 192 bytes gmem
ptxas info : Compiling entry function '_ZN5flash36flash_fwd_splitkv_mla_combine_kernelIN7cutlass10bfloat16_tEflLi512ELi160EEEv20Flash_fwd_mla_params' for 'sm_90a'
ptxas info : Function properties for _ZN5flash36flash_fwd_splitkv_mla_combine_kernelIN7cutlass10bfloat16_tEflLi512ELi160EEEv20Flash_fwd_mla_params
24 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 48 registers, 24 bytes cumulative stack size, 640 bytes smem
ptxas info : Compiling entry function '_ZN5flash36flash_fwd_splitkv_mla_combine_kernelIN7cutlass10bfloat16_tEflLi512ELi128EEEv20Flash_fwd_mla_params' for 'sm_90a'
ptxas info : Function properties for _ZN5flash36flash_fwd_splitkv_mla_combine_kernelIN7cutlass10bfloat16_tEflLi512ELi128EEEv20Flash_fwd_mla_params
24 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 48 registers, 24 bytes cumulative stack size, 512 bytes smem
ptxas info : Compiling entry function '_ZN5flash36flash_fwd_splitkv_mla_combine_kernelIN7cutlass10bfloat16_tEflLi512ELi96EEEv20Flash_fwd_mla_params' for 'sm_90a'
ptxas info : Function properties for _ZN5flash36flash_fwd_splitkv_mla_combine_kernelIN7cutlass10bfloat16_tEflLi512ELi96EEEv20Flash_fwd_mla_params
24 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 48 registers, 24 bytes cumulative stack size, 384 bytes smem
ptxas info : Compiling entry function '_ZN5flash36flash_fwd_splitkv_mla_combine_kernelIN7cutlass10bfloat16_tEflLi512ELi64EEEv20Flash_fwd_mla_params' for 'sm_90a'
ptxas info : Function properties for _ZN5flash36flash_fwd_splitkv_mla_combine_kernelIN7cutlass10bfloat16_tEflLi512ELi64EEEv20Flash_fwd_mla_params
24 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 48 registers, 24 bytes cumulative stack size, 256 bytes smem
ptxas info : Compiling entry function '_ZN5flash36flash_fwd_splitkv_mla_combine_kernelIN7cutlass10bfloat16_tEflLi512ELi32EEEv20Flash_fwd_mla_params' for 'sm_90a'
ptxas info : Function properties for _ZN5flash36flash_fwd_splitkv_mla_combine_kernelIN7cutlass10bfloat16_tEflLi512ELi32EEEv20Flash_fwd_mla_params
24 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 48 registers, 24 bytes cumulative stack size, 128 bytes smem
ptxas info : Compiling entry function '_ZN5flash28flash_fwd_splitkv_mla_kernelI27Flash_fwd_kernel_traits_mlaILi576ELi64ELi64ELi8EN7cutlass10bfloat16_tELi512EELb0ENS_16SharedStorageMLAIS4_EEEEv20Flash_fwd_mla_params' for 'sm_90a'
ptxas info : Function properties for _ZN5flash28flash_fwd_splitkv_mla_kernelI27Flash_fwd_kernel_traits_mlaILi576ELi64ELi64ELi8EN7cutlass10bfloat16_tELi512EELb0ENS_16SharedStorageMLAIS4_EEEEv20Flash_fwd_mla_params
104 bytes stack frame, 112 bytes spill stores, 124 bytes spill loads
ptxas info : Used 255 registers, 104 bytes cumulative stack size
ptxas info : Compiling entry function '_ZN5flash28flash_fwd_splitkv_mla_kernelI27Flash_fwd_kernel_traits_mlaILi576ELi64ELi64ELi8EN7cutlass10bfloat16_tELi512EELb1ENS_16SharedStorageMLAIS4_EEEEv20Flash_fwd_mla_params' for 'sm_90a'
ptxas info : Function properties for _ZN5flash28flash_fwd_splitkv_mla_kernelI27Flash_fwd_kernel_traits_mlaILi576ELi64ELi64ELi8EN7cutlass10bfloat16_tELi512EELb1ENS_16SharedStorageMLAIS4_EEEEv20Flash_fwd_mla_params
104 bytes stack frame, 116 bytes spill stores, 136 bytes spill loads
ptxas info : Used 255 registers, 104 bytes cumulative stack size
ptxas info : Compiling entry function '_Z23get_mla_metadata_kernel19Mla_metadata_params' for 'sm_90a'
ptxas info : Function properties for _Z23get_mla_metadata_kernel19Mla_metadata_params
24 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info : Used 32 registers, 24 bytes cumulative stack size, 32768 bytes smem
x86_64-linux-gnu-g++ -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -shared -Wl,-O1 -Wl,-Bsymbolic-functions /data/private/zhaodongyu/FlashMLA/build/temp.linux-x86_64-cpython-310/csrc/flash_api.o /data/private/zhaodongyu/FlashMLA/build/temp.linux-x86_64-cpython-310/csrc/flash_fwd_mla_bf16_sm90.o -L/usr/local/lib/python3.10/dist-packages/torch/lib -L/usr/local/cuda/lib64 -L/usr/lib/x86_64-linux-gnu -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-x86_64-cpython-310/flash_mla_cuda.cpython-310-x86_64-linux-gnu.so
creating build/bdist.linux-x86_64/egg
creating build/bdist.linux-x86_64/egg/flash_mla
copying build/lib.linux-x86_64-cpython-310/flash_mla/__init__.py -> build/bdist.linux-x86_64/egg/flash_mla
copying build/lib.linux-x86_64-cpython-310/flash_mla/flash_mla_interface.py -> build/bdist.linux-x86_64/egg/flash_mla
copying build/lib.linux-x86_64-cpython-310/flash_mla_cuda.cpython-310-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
byte-compiling build/bdist.linux-x86_64/egg/flash_mla/__init__.py to __init__.cpython-310.pyc
byte-compiling build/bdist.linux-x86_64/egg/flash_mla/flash_mla_interface.py to flash_mla_interface.cpython-310.pyc
creating stub loader for flash_mla_cuda.cpython-310-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/flash_mla_cuda.py to flash_mla_cuda.cpython-310.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying flash_mla.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying flash_mla.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying flash_mla.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying flash_mla.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.flash_mla_cuda.cpython-310: module references __file__
creating dist
creating 'dist/flash_mla-1.0.0+bcb90f2-py3.10-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing flash_mla-1.0.0+bcb90f2-py3.10-linux-x86_64.egg
creating /usr/local/lib/python3.10/dist-packages/flash_mla-1.0.0+bcb90f2-py3.10-linux-x86_64.egg
Extracting flash_mla-1.0.0+bcb90f2-py3.10-linux-x86_64.egg to /usr/local/lib/python3.10/dist-packages
Adding flash-mla 1.0.0+bcb90f2 to easy-install.pth file

Installed /usr/local/lib/python3.10/dist-packages/flash_mla-1.0.0+bcb90f2-py3.10-linux-x86_64.egg
Processing dependencies for flash-mla==1.0.0+bcb90f2
Finished processing dependencies for flash-mla==1.0.0+bcb90f2

0.3 Test results

  • b: batch size.
  • s_q: query sequence length.
  • mean_sk: mean key length per batch sample.
  • h_q: number of query heads.
  • h_kv: number of key-value heads.
  • d: head dimension (for query and key).
  • dv: value dimension.
  • causal: whether causal masking is enabled.
  • varlen: whether variable sequence lengths are used.
b s_q mean_sk h_q h_kv d dv causal varlen time(ms) GB/s
128 1 4096 16 1 576 512 True False 0.609 998
128 1 4096 16 1 576 512 True True 0.634 965
128 2 4096 16 1 576 512 True False 0.610 1004
128 2 4096 16 1 576 512 True True 0.634 1001
128 1 4096 32 1 576 512 True False 0.613 1000
128 1 4096 32 1 576 512 True True 0.645 998
128 2 4096 32 1 576 512 True False 0.618 1005
128 2 4096 32 1 576 512 True True 0.622 995
128 1 4096 64 1 576 512 True False 0.621 1002
128 1 4096 64 1 576 512 True True 0.631 965
128 2 4096 64 1 576 512 True False 1.186 539
128 2 4096 64 1 576 512 True True 1.261 528
128 1 4096 128 1 576 512 True False 1.191 537
128 1 4096 128 1 576 512 True True 1.235 523
128 2 4096 128 1 576 512 True False 2.394 282
128 2 4096 128 1 576 512 True True 2.516 276
128 1 8192 16 1 576 512 True False 1.175 1032
128 1 8192 16 1 576 512 True True 1.198 1012
128 2 8192 16 1 576 512 True False 1.173 1037
128 2 8192 16 1 576 512 True True 1.271 1024
128 1 8192 32 1 576 512 True False 1.178 1033
128 1 8192 32 1 576 512 True True 1.147 1012
128 2 8192 32 1 576 512 True False 1.183 1036
128 2 8192 32 1 576 512 True True 1.151 1017
128 1 8192 64 1 576 512 True False 1.190 1030
128 1 8192 64 1 576 512 True True 1.162 1011
128 2 8192 64 1 576 512 True False 2.305 539
128 2 8192 64 1 576 512 True True 2.408 532
128 1 8192 128 1 576 512 True False 2.314 537
128 1 8192 128 1 576 512 True True 2.508 529
128 2 8192 128 1 576 512 True False 4.698 272
128 2 8192 128 1 576 512 True True 4.648 270
Test run log
# python tests/test_flash_mla.py
b=128, s_q=1, mean_sk=4096, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 36, n_repeat: 144
0.609 ms, 30 TFLOPS, 998 GB/s
b=128, s_q=1, mean_sk=4096, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 35, n_repeat: 140
0.634 ms, 29 TFLOPS, 965 GB/s
b=128, s_q=2, mean_sk=4096, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 36, n_repeat: 145
0.610 ms, 60 TFLOPS, 1004 GB/s
b=128, s_q=2, mean_sk=4096, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 35, n_repeat: 140
0.634 ms, 60 TFLOPS, 1001 GB/s
b=128, s_q=1, mean_sk=4096, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 36, n_repeat: 145
0.613 ms, 60 TFLOPS, 1000 GB/s
b=128, s_q=1, mean_sk=4096, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 34, n_repeat: 138
0.645 ms, 59 TFLOPS, 998 GB/s
b=128, s_q=2, mean_sk=4096, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 35, n_repeat: 143
0.618 ms, 118 TFLOPS, 1005 GB/s
b=128, s_q=2, mean_sk=4096, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 35, n_repeat: 143
0.622 ms, 117 TFLOPS, 995 GB/s
b=128, s_q=1, mean_sk=4096, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 35, n_repeat: 143
0.621 ms, 118 TFLOPS, 1002 GB/s
b=128, s_q=1, mean_sk=4096, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 35, n_repeat: 141
0.631 ms, 113 TFLOPS, 965 GB/s
b=128, s_q=2, mean_sk=4096, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 19, n_repeat: 79
1.186 ms, 123 TFLOPS, 539 GB/s
b=128, s_q=2, mean_sk=4096, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 18, n_repeat: 74
1.261 ms, 121 TFLOPS, 528 GB/s
b=128, s_q=1, mean_sk=4096, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 19, n_repeat: 79
1.191 ms, 123 TFLOPS, 537 GB/s
b=128, s_q=1, mean_sk=4096, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 19, n_repeat: 76
1.235 ms, 120 TFLOPS, 523 GB/s
b=128, s_q=2, mean_sk=4096, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 10, n_repeat: 40
2.394 ms, 122 TFLOPS, 282 GB/s
b=128, s_q=2, mean_sk=4096, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 9, n_repeat: 38
2.516 ms, 120 TFLOPS, 276 GB/s
b=128, s_q=1, mean_sk=8192, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 19, n_repeat: 79
1.175 ms, 31 TFLOPS, 1032 GB/s
b=128, s_q=1, mean_sk=8192, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 19, n_repeat: 78
1.198 ms, 30 TFLOPS, 1012 GB/s
b=128, s_q=2, mean_sk=8192, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 20, n_repeat: 80
1.173 ms, 62 TFLOPS, 1037 GB/s
b=128, s_q=2, mean_sk=8192, h_q=16, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 18, n_repeat: 74
1.271 ms, 61 TFLOPS, 1024 GB/s
b=128, s_q=1, mean_sk=8192, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 19, n_repeat: 79
1.178 ms, 62 TFLOPS, 1033 GB/s
b=128, s_q=1, mean_sk=8192, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 20, n_repeat: 81
1.147 ms, 61 TFLOPS, 1012 GB/s
b=128, s_q=2, mean_sk=8192, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 19, n_repeat: 79
1.183 ms, 123 TFLOPS, 1036 GB/s
b=128, s_q=2, mean_sk=8192, h_q=32, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 20, n_repeat: 81
1.151 ms, 121 TFLOPS, 1017 GB/s
b=128, s_q=1, mean_sk=8192, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 19, n_repeat: 79
1.190 ms, 123 TFLOPS, 1030 GB/s
b=128, s_q=1, mean_sk=8192, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 20, n_repeat: 80
1.162 ms, 120 TFLOPS, 1011 GB/s
b=128, s_q=2, mean_sk=8192, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 10, n_repeat: 41
2.305 ms, 127 TFLOPS, 539 GB/s
b=128, s_q=2, mean_sk=8192, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 10, n_repeat: 40
2.408 ms, 125 TFLOPS, 532 GB/s
b=128, s_q=1, mean_sk=8192, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 10, n_repeat: 41
2.314 ms, 126 TFLOPS, 537 GB/s
b=128, s_q=1, mean_sk=8192, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 9, n_repeat: 38
2.508 ms, 125 TFLOPS, 529 GB/s
b=128, s_q=2, mean_sk=8192, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=False
n_warmup: 5, n_repeat: 20
4.698 ms, 124 TFLOPS, 272 GB/s
b=128, s_q=2, mean_sk=8192, h_q=128, h_kv=1, d=576, dv=512, causal=True, varlen=True
n_warmup: 5, n_repeat: 21
4.648 ms, 123 TFLOPS, 270 GB/s

1. Code walkthrough

Now let's dig into the source. This post uses the following configuration as the running example:

  • b = 128
  • s = 4096
  • h_q = 32 (TP=4)
  • s_q = 1 (MTP = 1)
  • varlen = False

Start from the test_flash_mla() function in tests/test_flash_mla.py.

cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32)  # Every sample's key-value sequence length starts at mean_sk.
if varlen:  # If varlen is True, draw each sample's KV length from N(mean_sk, mean_sk/2), clamped to at least s_q (the query length).
    for i in range(b):
        cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q)
total_seqlens = cache_seqlens.sum().item()  # Sum of all KV sequence lengths in the batch: 524288 = 4096 * 128
mean_seqlens = cache_seqlens.float().mean().int().item()  # Mean KV sequence length in the batch: 4096
max_seqlen = cache_seqlens.max().item()  # Maximum KV sequence length in the batch: 4096
max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256  # Round max_seqlen up to a multiple of 256: 4096
# print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")

q = torch.randn(b, s_q, h_q, d)
block_size = 64  # Block size used for paging the key-value cache.
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)  # [128, 64], 64 = 4096 // 64
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)  # [8192, 64, 1, 576]
for i in range(b):
    blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item():] = float("nan")
blocked_v = blocked_k[..., :dv]  # The first dv channels of blocked_k serve as the value

Taking the first dv channels of the key tensor as the value avoids allocating a separate blocked_v tensor, which saves memory and improves data locality.

In attention, the key (K) dimension d can be larger than the value (V) dimension dv:

  • K is only used to compute the attention weights (softmax(QK^T) against the query Q).
  • V is only used for the weighted sum, so dv can be smaller than d to reduce computation.

In the current example,

  • blocked_k has shape [8192, 64, 1, 576]
  • blocked_v has shape [8192, 64, 1, 512]

where 8192 = 128 (batch) * 4096 (max_seqlen_pad) // 64 (block_size), and 64 is the block_size.
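As a quick aside, the slice above is a view rather than a copy. A minimal check (my own throwaway snippet, not part of the test) confirms that blocked_v shares storage with blocked_k:

import torch

d, dv = 576, 512
blocked_k = torch.randn(8192, 64, 1, d, dtype=torch.bfloat16)
blocked_v = blocked_k[..., :dv]  # a view over the first dv channels of K
assert blocked_v.data_ptr() == blocked_k.data_ptr()  # no new memory is allocated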

The computation to implement is shown below:

In the figure, part of V and of K is shaded light green, indicating that they share the same data. QK = P is defined as gemm 1 and PV = O as gemm 2. The figure is adapted from LingYe.

1.1 Baseline implementation: ref_mla

First, the plain PyTorch reference. If you are already familiar with attention, feel free to skip this part.

ref_mla implements

\[\text{Attention}(Q,K,V) = \text{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V\]

def ref_mla():
    out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32)  # Output buffer
    lse = torch.empty(b, h_q, s_q, dtype=torch.float32)      # Log-sum-exp buffer
    for i in range(b):  # Compute batch by batch
        begin = i * max_seqlen_pad  # Key-value index range for this batch
        end = begin + cache_seqlens[i]
        O, LSE = scaled_dot_product_attention(  # Attention output and softmax normalizer (logsumexp)
            q[i].transpose(0, 1),
            blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
            blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
            h_q=h_q,
            h_kv=h_kv,
            is_causal=causal,
        )
        out[i] = O.transpose(0, 1)
        lse[i] = LSE
    return out, lse

out has shape [b, s_q, h_q, dv]

lse has shape [b, h_q, s_q]

Looking at just one batch:

q[i].transpose(0, 1).shape = torch.Size([32, 1, 576])
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1).shape = torch.Size([1, 4096, 576])
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1).shape = torch.Size([1, 4096, 512])
h_q = 32
h_kv = 1

Next, the scaled_dot_product_attention function:

def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    query = query.float()  # Inputs were bfloat16; convert to float32
    key = key.float()
    value = value.float()
    key = key.repeat_interleave(h_q // h_kv, dim=0)      # Expand KV heads for multi-head attention
    value = value.repeat_interleave(h_q // h_kv, dim=0)  # [1, 4096, 512] -> [32, 4096, 512]
    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:  # For autoregressive Transformers
        s_q = query.shape[-2]  # 1
        s_k = key.shape[-2]    # 4096
        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)  # Lower-triangular mask: each token attends only to earlier tokens
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)  # Since s_q = 1, the bias is all zeros here and has no effect
        attn_weight += attn_bias
    lse = attn_weight.logsumexp(dim=-1)  # Log-sum-exp, used later for rescaling
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)  # Convert scores to probabilities
    return attn_weight @ value, lse  # Weighted sum over the values

Here,

attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))

computes the attention weights:

  • Scaled dot-product attention:

    \[\text{attn\_weight} = \frac{QK^T}{\sqrt{d_k}}\]

    where \(d_k\) is the size of the last dimension of query, used for scaling to keep the logits/gradients from blowing up.
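The lse returned above is what later allows partial results to be merged, via the identity softmax(x) = exp(x - logsumexp(x)). A quick numerical check (my own snippet, not part of the test):

import torch

x = torch.randn(4, 4096)
lse = x.logsumexp(dim=-1, keepdim=True)
assert torch.allclose(torch.softmax(x, dim=-1), torch.exp(x - lse), atol=1e-6)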

1.2 The main event: flash_mla

Now for the real thing!

Overall, flash_mla consists of two functions:

  • get_mla_metadata
    • handles token-level load balancing
    • computes tile_scheduler_metadata and num_splits, used by the later stages
  • flash_mla_with_kvcache
    • performs the paged attention computation

1.2.1 Load balancing: get_mla_metadata

A batch contains many sequences, each with a different length. If the kernel's grid size were simply set to the batch size, the work would be unbalanced. So the total sequence length is computed first and then distributed evenly across the SMs; some sequences get cut in the process, and get_mla_metadata records these split points. This is a common trick in all the flash-style kernels. Source: 刘俊是

tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

Inputs:

  • cache_seqlens
    • length batch_size; here torch.Size([128])
  • s_q * h_q // h_kv
    • num_heads_per_head_k = 1 * 32 // 1
  • h_kv
    • num_heads_k = 1

Outputs:

  • tile_scheduler_metadata
    • (num_sm_parts, TileSchedulerMetaDataSize = 8)
    • torch.Size([78, 8])
  • num_splits
    • batch_size + 1
    • torch.Size([129])
    • records into how many thread blocks each batch's key sequence length is split
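A minimal call for this configuration looks like the following (shapes as reported above; num_sm_parts depends on the GPU's SM count, so the 78 is specific to this H20):

import torch
from flash_mla import get_mla_metadata

b, s_q, h_q, h_kv, mean_sk = 128, 1, 32, 1, 4096
cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32, device="cuda")
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
print(tile_scheduler_metadata.shape, num_splits.shape)  # torch.Size([78, 8]) torch.Size([129]) on this machine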

The function is implemented in flash_api.cpp. It computes metadata for the GPU computation, dividing SM resources among the num_heads_k heads so the load stays balanced.

num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m);
// 78 = 78 / 1 / cutlass::ceil_div(32, 64)

num_sm_parts is the number of thread blocks running in parallel along the key sequence length dimension (similar to flash decoding), filling up a full wave of parallelism to improve GPU utilization. -- from CalebDu

The actual computation happens in get_mla_metadata_kernel, called from get_mla_metadata_func in csrc/flash_fwd_mla_kernel.h.

Note that although one dimension of tile_scheduler_metadata is TileSchedulerMetaDataSize = 8, only 5 of the entries are used; 8 is chosen for int4 (16 B) alignment.

The 5 entries are: the starting seq idx each SM handles, the token idx within that starting seq, the ending seq idx, the token idx within the ending seq, and whether the starting seq has been split.

tile_scheduler_metadata0[0] = now_idx;
tile_scheduler_metadata0[1] = now_block * block_size_n; // static constexpr int block_size_n = 64;
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
tile_scheduler_metadata1 = now_n_split_idx;

Printing tile_scheduler_metadata:

sm begin_idx begin_seqlen end_idx end_seqlen n_split_idx unused unused unused
0 0 0 1 2880 0 0 0 0
1 1 2880 3 1344 1 0 0 0
2 3 1344 4 4096 1 0 0 0
3 5 0 6 2880 0 0 0 0
4 6 2880 8 1344 1 0 0 0
5 8 1344 9 4096 1 0 0 0
6 10 0 11 2880 0 0 0 0
7 11 2880 13 1344 1 0 0 0
8 13 1344 14 4096 1 0 0 0
9 15 0 16 2880 0 0 0 0
10 16 2880 18 1344 1 0 0 0
... ... ... ... ... ... 0 0 0
74 123 1344 124 4096 1 0 0 0
75 125 0 126 2880 0 0 0 0
76 126 2880 127 4096 1 0 0 0
77 128 0 127 4096 0 0 0 0

The metadata records each thread block's start and end information. This resembles the stream-K idea in the Marlin GEMM kernel I studied earlier: the work is partitioned so that the thread blocks working on the key sequence lengths of different batches stay load balanced.

Most of this is easy to follow; n_split_idx is worth working through against the source:

The initial budget is payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;

This is the number of task blocks each SM part should process. The ceiling division keeps the blocks evenly distributed; fixed_overhead_num_blocks is 5, and its purpose is (TBD).

In this example that works out to ceil(8832 / 78) + 5 = 119.

SM 0:

  • now_idx = 0 --> begin_idx = 0

  • now_block = 0 --> begin_seqlen = 0 * 64 = 0

  • n_split_idx = 0

  • Entering the while loop:

    • now_block = 0, num_blocks = 64 --> now_remain_blocks = 64
    • remain_payload = 119 vs. now_remain_blocks + fixed_overhead_num_blocks = 69
      • Enough to cover it, 50 left over --> cum_num_splits + 1 (cumulative split count)
    • remain_payload = 50 vs. now_remain_blocks + fixed_overhead_num_blocks = 69
      • Not enough: now_block = remain_payload - fixed_overhead_num_blocks = 45
      • now_n_split_idx++ (this row needs one more split)

SM 1:

  • now_idx = 1 --> begin_idx = 1

  • now_block = 45 --> begin_seqlen = 45 * 64 = 2880

  • n_split_idx = 1 (meaning this SM starts from the 1st continuation of a split row)

  • Entering the while loop:

    • now_block = 45, num_blocks = 64 --> now_remain_blocks = 19 (19 blocks of the previous row are still unprocessed)
    • remain_payload = 119 vs. now_remain_blocks + fixed_overhead_num_blocks = 24
      • Enough to cover it, 95 left over --> cum_num_splits + 1 (cumulative split count)
    • remain_payload = 95 vs. now_remain_blocks + fixed_overhead_num_blocks = 69
      • Enough to cover it, 26 left over --> cum_num_splits + 1 (cumulative split count)
    • remain_payload = 26 vs. now_remain_blocks + fixed_overhead_num_blocks = 69
      • Not enough: now_block = remain_payload - fixed_overhead_num_blocks = 21
      • now_n_split_idx++ (this row needs one more split)

SM 2:

  • now_idx = 3 --> begin_idx = 3 (matching the metadata table above)

  • now_block = 21 --> begin_seqlen = 21 * 64 = 1344

  • n_split_idx = 1 (meaning this SM starts from the 1st continuation of a split row)

...

So the work is divided into 78 sm_parts; in this example each sm_part handles several rows, and whenever a row is left unfinished, n_split_idx is incremented. In other words, from an SM's point of view, n_split_idx is the index of which chunk along the n (key-length) dimension it starts from.
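To make the scheduling concrete, here is a small Python sketch of the loop walked through above (my reconstruction, not the actual CUDA kernel; edge cases may differ). With 128 sequences of length 4096, block_size_n = 64 and 78 SM parts it reproduces the metadata rows printed earlier:

import math

def schedule(seqlens_k, num_sm_parts, block_size_n=64, fixed_overhead_num_blocks=5):
    num_blocks = [math.ceil(s / block_size_n) for s in seqlens_k]
    total_num_blocks = sum(num_blocks) + fixed_overhead_num_blocks * len(seqlens_k)
    payload = math.ceil(total_num_blocks / num_sm_parts) + fixed_overhead_num_blocks  # 119 here
    metadata = []
    now_idx, now_block, now_n_split_idx = 0, 0, 0
    for _ in range(num_sm_parts):
        begin_idx, begin_seqlen, begin_n_split_idx = now_idx, now_block * block_size_n, now_n_split_idx
        remain_payload = payload
        while now_idx < len(seqlens_k):
            now_remain_blocks = num_blocks[now_idx] - now_block
            if remain_payload >= now_remain_blocks + fixed_overhead_num_blocks:
                # This SM can finish the current sequence and move on to the next one.
                remain_payload -= now_remain_blocks + fixed_overhead_num_blocks
                now_idx, now_block, now_n_split_idx = now_idx + 1, 0, 0
            else:
                # Budget exhausted: split the current sequence and hand the rest to the next SM.
                if remain_payload > fixed_overhead_num_blocks:
                    now_block += remain_payload - fixed_overhead_num_blocks
                    now_n_split_idx += 1
                break
        end_idx = now_idx if now_block > 0 else now_idx - 1
        end_seqlen = now_block * block_size_n if now_block > 0 else seqlens_k[now_idx - 1]
        metadata.append((begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx))
    return metadata

rows = schedule([4096] * 128, num_sm_parts=78)
print(rows[0], rows[1], rows[2])  # (0, 0, 1, 2880, 0) (1, 2880, 3, 1344, 1) (3, 1344, 4, 4096, 1)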

1.2.2 Paged attention: flash_mla_with_kvcache

Back to the main thread: the flash_mla_with_kvcache function.

flash_mla_with_kvcache(
    q, blocked_k, block_table, cache_seqlens, dv,
    tile_scheduler_metadata, num_splits, causal=causal,
)
argument    shape (value)    meaning
q [128, 1, 32, 576] (batch_size, seq_len_q, num_heads_q, head_dim)
blocked_k [8192, 64, 1, 576] (num_blocks, page_block_size, num_heads_k, head_dim)
block_table [128, 64] (batch_size, max_num_blocks_per_seq), torch.int32
cache_seqlens [128] (batch_size), torch.int32
dv 512 Head_dim of v
tile_scheduler_metadata [78, 8] (num_sm_parts, TileSchedulerMetaDataSize), torch.int32
num_splits [129] (batch_size + 1), torch.int32

Note that q here has already absorbed the KV up-projection matrices, so the kernel can compute directly against the merged KV cache.

In flash_api.cpp, mha_fwd_kvcache_mla -> run_mha_fwd_splitkv_mla -> run_flash_splitkv_fwd_mla launches two kernels:

  • flash_fwd_splitkv_mla_kernel
  • flash_fwd_splitkv_mla_combine_kernel

The two kernels run on the same stream, so they execute one after the other.

template<typename Kernel_traits, typename SharedStorage>
void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
    FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
    const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
    BOOL_SWITCH(params.is_causal, Is_causal, [&] {
        auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
        constexpr size_t smem_size = sizeof(SharedStorage);
        CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
    });
    CHECK_CUDA_KERNEL_LAUNCH();

    dim3 grid_combine(params.b * params.h * params.seqlen_q);
    MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
        auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
                typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
        combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
    });
    CHECK_CUDA_KERNEL_LAUNCH();
}

Going through it step by step:

  1. num_m_block is the number of tiles needed along the M dimension: the kernel tiles along M and loops over n_block along the sequence-length direction. Here blockM = 64 and blockN = 64.

num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);

In my example, h_q (i.e. params.seqlen_q) = 32, which is less than blockM, so num_m_block = 1.

  2. The kernel flash_fwd_splitkv_mla_kernel is specialized on the values of Kernel_traits and Is_causal.

  3. smem_size: the shared-memory size. SharedStorage is the struct holding the shared data; its size is obtained with sizeof, and cudaFuncSetAttribute is then called to set the kernel's maximum dynamic shared-memory size.

This is worth computing carefully: smem_size = 230400 bytes (225 KB).

union{
    struct{
        - smem_q: 73728
            - holds the input Q
            - 576 * 64 * 2B / 1024 = 72 KB
        - smem_k: 73728 * 2 = 147456 (double buffer)
            - holds the input K (which contains part of V)
            - 64 * 576 * 2B * 2 / 1024 = 144 KB
        - smem_p: 8192
            - holds the gemm 1 result, used to hand data between `wg 0` and `wg 1`
            - 2 x 2 x 128 x 8 x 2B / 1024 = 8 KB
        - smem_scale: 1024
    }
    struct{
        - smem_max: 1024
        - smem_sum: 1024
        - smem_o  : 131072
    }
}
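A quick arithmetic check of the first struct above (bf16 elements are 2 bytes each; my own throwaway snippet):

smem_q = 576 * 64 * 2          # 73728 B  = 72 KB, input Q tile
smem_k = 64 * 576 * 2 * 2      # 147456 B = 144 KB, double-buffered K tiles
smem_p = 2 * 2 * 128 * 8 * 2   # 8192 B   = 8 KB, gemm 1 result handed between warp groups
smem_scale = 1024
print((smem_q + smem_k + smem_p + smem_scale) / 1024)  # 225.0 KB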
Data Center GPU NVIDIA V100 NVIDIA A100 NVIDIA H100
GPU architecture NVIDIA Volta NVIDIA Ampere NVIDIA Hopper
Compute capability 7.0 8.0 9.0
Shared memory size / SM Configurable up to 96 KB Configurable up to 164 KB Configurable up to 228 KB

Source of the table data

smem_size is 225 KB, while the Hopper architecture offers 228 KB of shared memory per SM; from this angle, only Hopper cards can run this algorithm comfortably.

  4. kernel<<<...>>>: launches the CUDA kernel with its grid size and block size.
  • dim3(num_m_block, params.h, params.num_sm_parts) is the grid size
    • dim3(1, num_heads = 1, 78)
  • Kernel_traits::kNThreads is the number of threads per block
    • 256
  • smem_size is the shared-memory size
    • 230400
  • stream is the CUDA stream
  5. Because params.num_sm_parts is 78, MLA_NUM_SPLITS_SWITCH sets kMaxSplits to 96.
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
    auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
            typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
    combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
});

Launch the CUDA kernel combine_kernel:

  • combine_kernel<<<grid_combine, 128, 0, stream>>>
    • grid_combine is a one-dimensional dim3 giving the combine kernel's grid size.
    • When only one number is given, CUDA maps it to the x dimension and sets y and z to 1.
    • dim3(batch_size * num_heads * seqlen_q = 128 * 1 * 32 = 4096, 1, 1)
    • This combine_kernel allocates no extra dynamic shared memory.

Now let's look at the two kernels.

1.2.2.1 flash_fwd_splitkv_mla_kernel

  • dim3(num_m_block, params.h, params.num_sm_parts)
  • dim3(1, num_heads = 1, 78)
#pragma unroll 1
for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
    const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
    const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
    const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
    const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
    const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
    if (batch_id > begin_idx) {
        __syncthreads(); // Barrier between two tiles.
    }
    flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
}

Focusing on SM 1, whose begin_idx = 1 and end_idx = 3, the loop runs three times:

batch_id n_split_idx seqlen_k n_block_min n_block_max NoSplit
1 1 4096 45 64 0
2 0 4096 0 64 1
3 0 4096 0 21 0
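The same numbers fall out of a few lines of Python applied to SM 1's metadata row (begin_idx=1, begin_seqlen=2880, end_idx=3, end_seqlen=1344, begin_n_split_idx=1), with kBlockN = 64 and seqlen_k = 4096; this is a sketch of the index math only, not the kernel:

import math

begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx = 1, 2880, 3, 1344, 1
kBlockN, seqlen_k = 64, 4096
for batch_id in range(begin_idx, end_idx + 1):
    n_split_idx = begin_n_split_idx if batch_id == begin_idx else 0
    n_block_min = begin_seqlen // kBlockN if batch_id == begin_idx else 0
    n_block_max = math.ceil(end_seqlen / kBlockN) if batch_id == end_idx else math.ceil(seqlen_k / kBlockN)
    no_split = n_block_min == 0 and n_block_max == math.ceil(seqlen_k / kBlockN)
    print(batch_id, n_split_idx, n_block_min, n_block_max, no_split)
# prints: 1 1 45 64 False / 2 0 0 64 True / 3 0 0 21 False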

Digging further into compute_attn_1rowblock_splitkv_mla:

First, some parameters:

  • kBlockM = 64
  • kBlockN = 64
  • kHeadDim = 576
  • kHeadDimV = 512
  • kNThreads = 256
  • kNThreadsS = 128
/// Returns a warp-uniform value indicating the canonical warp group index of the calling threads.
/// Threads within the warp must be converged.
CUTLASS_DEVICE
int canonical_warp_group_idx() {
#if defined(__CUDA_ARCH__)
    return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarpGroup, 0); // NumThreadsPerWarpGroup = 128
#else
    return 0;
#endif
}

cutlass::canonical_warp_group_idx() splits each thread block's 256 threads into 2 warp groups: threads 0-127 are wg0 and threads 128-255 are wg1.

The logic that follows is easier to convey with a diagram:

warp group 0:

typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)

This partitions matrix fragments to the threads for the MMA (matrix multiply-accumulate) operations.

if (n_block % 2 == 1) {
    // Double buffer for sK
    constexpr int sK_offset = size(sK);
    tSrK.data() = tSrK.data() + sK_offset / 8;
    tOrVt.data() = tOrVt.data() + sK_offset / 8;
}

This is the double-buffer logic: when n_block is odd, an offset of sK_offset / 8 is added.

  • Why sK_offset / 8? My understanding is that a single wgmma is 64x64x16, and the 8 here is 64 * 2 / 16.
for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {

The for loop iterates over n_block.

Each iteration performs the gemm 1 computation, masking, online softmax, and stores the intermediate results.

tiled_mma has shape 64,64,576; a single wgmma is 64x64x16, iterating along the K direction.

tiled_mma_o has shape 64,512,64 and is split into 2 MMAs along N; a single wgmma is 64x256x16, and warp group 0 computes one of the two halves.

warp group 1:

wg1 is responsible for loading Q, K, and P, and computes the second half of tiled_mma_o.

Interestingly, wgmma supports at most N = 256, which is exactly half of headdimV, so the two warp groups together complete one full gemm 2. LingYe

The loads are indexed through block_table, and the KV buffer is switched based on the parity of n_block.

After the n_block loop finishes, a sum/max synchronization is done via SoftmaxReady so that both warp groups see the same data, and then they store the output together.

After the MLA kernel there is still the combine kernel, which reduces the num_sm_parts partial results into the complete output.

1.2.3 flash_fwd_splitkv_mla_combine_kernel

  • dim3(batch_size * num_heads * seqlen_q = 128 * 1 * 32 = 4096, 1, 1)

This kernel handles the split-KV results: it performs the additions, the softmax summation, and the related tensor operations.

const int split_offset = __ldg(params.num_splits_ptr + batch_idx);
const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;
FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);
if (actual_num_splits == 1) return;

It reads the current batch's split count from params.num_splits_ptr and checks that actual_num_splits does not exceed kMaxSplits. If actual_num_splits == 1, it returns early and does no work.

Once again, this is easier to follow together with the figure above:

  • batch_idx = 0 --> actual_num_splits = 1 --> return
  • batch_idx = 1 --> actual_num_splits = 2 --> ...
  • batch_idx = 2 --> actual_num_splits = 1 --> return
  • batch_idx = 3 --> actual_num_splits = 3 --> ...

The following block computes the LSE: first the maximum, then the sum, then the global value:

int warp_idx = cutlass::canonical_warp_idx_sync();  // 32 threads per warp
if (warp_idx == 0) {  // Only threads 0-31
    constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);

    float local_lse[kNLsePerThread];
    for (int i = 0; i < kNLsePerThread; ++i) {
        const int split = i * 32 + tidx;
        local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY;
    }

    float max_lse = -INFINITY;
    for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]);
    for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
    max_lse = max_lse == -INFINITY ? 0.0f : max_lse;  // In case all local LSEs are -inf

    float sum_lse = 0;
    for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse);
    for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);

    float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse;
    if (tidx == 0) gLSE(0) = global_lse;

    for (int i = 0; i < kNLsePerThread; ++i) {  // Store each split's LSE rescaling factor to shared memory
        const int split = i * 32 + tidx;
        if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse);
    }
}
__syncthreads();

Each split is then weighted and accumulated into the tOrO tensor:

for (int split = 0; split < actual_num_splits; ++split) {
    cute::copy(tOgOaccum, tOrOaccum);
    ElementAccum lse_scale = sLseScale[split];
    for (int i = 0; i < size(tOrO); ++i) {
        tOrO(i) += lse_scale * tOrOaccum(i);
    }
    tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV;
}

Finally, the result is converted to the output dtype and written back to global memory.

In short, this kernel merges the results of the multiple splits within each thread block, using the softmax LSE and the derived scale factors for the weighted sum.
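The math behind this merging can be verified in a few lines of PyTorch (my own sketch, independent of the kernel): split the keys, compute a partial softmax attention and LSE per split, then re-weight each partial output by exp(lse_i - global_lse).

import torch

def combine_splits(scores, v, splits):
    parts, lses = [], []
    for lo, hi in splits:
        s = scores[:, lo:hi]
        lses.append(s.logsumexp(dim=-1))                    # per-split LSE
        parts.append(torch.softmax(s, dim=-1) @ v[lo:hi])   # per-split partial output
    lse = torch.stack(lses)                                 # [num_splits, rows]
    global_lse = lse.logsumexp(dim=0)
    scale = torch.exp(lse - global_lse)                     # per-split rescaling factor
    return sum(w.unsqueeze(-1) * o for w, o in zip(scale, parts))

scores, v = torch.randn(4, 128), torch.randn(128, 8)
ref = torch.softmax(scores, dim=-1) @ v
out = combine_splits(scores, v, [(0, 45), (45, 128)])
assert torch.allclose(out, ref, atol=1e-5)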

2. Performance

A quick look at the kernel timings with Nsight Systems:

nsys profile --trace=cuda,osrt -o flash_mla --force-overwrite true python tests/test_flash_mla.py

What does this have to do with MLA?

I had previously read several articles explaining MLA, such as 苏剑林's 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA and ZHANG Mingxing's DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子.

While reading the kernel I kept wondering: isn't this just computing attention? The projection matrices W^UQ and W^UK never show up, so what does this have to do with MLA?

It turns out those matrices have already been absorbed into the Q matrix; FlashMLA is precisely the efficient inference kernel for the MLA head dimensions.
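A tiny sketch of that absorption (my own toy shapes and weight names, not DeepSeek's actual parameters): attention scores against the compressed latent c_kv can be computed either by up-projecting K first, or by folding W_uk into the query once and attending directly in the latent space.

import torch

d_latent, d_head, n_tokens = 512, 128, 16
c_kv = torch.randn(n_tokens, d_latent)   # compressed KV-cache entries
q    = torch.randn(1, d_head)            # one query head
W_uk = torch.randn(d_head, d_latent)     # K up-projection

scores_naive    = q @ (c_kv @ W_uk.T).T  # up-project K, then Q.K^T
scores_absorbed = (q @ W_uk) @ c_kv.T    # absorb W_uk into the query
assert torch.allclose(scores_naive, scores_absorbed, atol=1e-4)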

Summary

MLA is essentially an MQA with partially shared KV and a larger head dimension: the dimension grows from the usual 128 to 576/512, K and V share the first 512 channels, and the remaining 64 channels belong to K only. FlashMLA can only run on the Hopper architecture and is almost impossible to port to other platforms (unless they have more than 228 KB of shared memory and wgmma with N larger than 256). FlashMLA combines the online softmax algorithm, paged-attention blocking, and split-KV, plus its own compute mapping and a pipeline in which two warp groups cooperate, to reach very high performance. LingYe

ZHANG Mingxing's repo also describes a move_elision optimization with significant gains, but as far as I can tell FlashMLA does not use it; I wonder what effect adopting it would have.

Reading through the code carefully taught me a lot about clever CUDA kernel design. Next I plan to actually put FlashMLA to use and to study the other projects DeepSeek has open-sourced.

I am still a beginner with a shallow understanding; corrections from more experienced readers are very welcome!

References

FlashMLA

flashMLA 深度解析

CalebDu

DeepSeek: FlashMLA代码解析

Deepseek FlashMLA解析

怎样评价NVIDIA新一代的Hopper GPU架构?

缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA

用cutlass cute实现flash attention

deepseekv2-profile
