Megatron-LM源码系列(八): Context Parallel并行
1. Context Parallel并行原理介绍
megatron中的context并行(简称CP
)与sequence并行(简称SP
)不同点在于,SP只针对Layernorm
和Dropout
输出的activation在sequence维度上进行切分,CP则是对所有的input输入和所有的输出activation在sequence维度上进行切分,可以看成是增强版的SP。除了Attention模块以外,其他的模块(Layernorm、Dropout)由于没有多token的处理,在CP并行时都不用任何修改。
为什么Attention模块是个例外?
因为Attention计算过程中每个token的Q(query)要跟同一个sequence中其他token的K(key)和V(value)一起进行计算,存在计算上的依赖,所以通过CP并行后,在计算Attention前要通过allgather
通信拿到所有token的KV向量,在反向计算时对应需要通过reduce_scatter
分发gradient梯度。
为了减少显存占用,在前向时每个gpu只用保存一部分KV块,反向时通过allgather
通信拿到所有的KV数据。KV的通信发生在相邻TP通信组相同位置的rank之间。allgather和reduce_scatter在ring拓扑架构实现时,底层会通过send和recv来进行实现。
以上图TP2-CP2的transformer网络为例,在Attention前的是CP的通信算子,其他都是TP的通信算子。AG
表示allgather,
RS
表示reduce_scatter,
AG/RS
表示前向allgather反向reduce_scatter,
RS/AG
表示前向reduce_scatter反向allgather。
这里TP2对应为[GPU0, GPU1], [GPU2, GPU3],
CP2对应为TP组相同位置的rank号,也就是[GPU0, GPU2], [GPU1,
GPU3]。CP并行与Ring
Attention类似,但是提供了新的OSS与FlashAttention版本,也去除了low-triangle causal masking
的冗余计算。
LLM经常由于sequence长度过长导致显存OOM,这时之前的一种方式是通过重计算的方式保存中间的activation产出,全量重计算的劣势会带来30%的计算代价;另外一种方式是扩大TP(tensor parallel)的大小,扩大TP的劣势在于会对tensor切的更小,从而导致linear fc的计算时间变少,从而与通信很难进行计算的掩盖。
通过CP可以更好解决OOM的问题,每个GPU只用处理一部分的sequence,
同时减少CP倍的通信和计算,但保持TP不变,同时activation也会减少CP倍。CP优化的性能参考如下图,在Megatron中(Megatron-Core>=0.5.0
&& Transformer Engine
>=1.1)通过指定--context-parallel-size
可以进行使用。\(total\_gpu\_count = TP \times CP \times PP \times
DP\)。
2. 源码
以Megatron-Core 0.5.0为例进行介绍
- 首先在
megatron/arguments.py
中定义了--context-parallel-size
参数, 同时也要求了world_size能要整除TP*PP*CP
。
1 | group.add_argument('--context-parallel-size', type=int, default=1, |
- 在
megatron/core/parallel_state.py
中初始化通信组时会初始化相关CP通信组, 以TP-PP-DP-CP=8-1-1-2为例,TP通信组为[0,1,2,3,4,5,6,7],[8,9,10,11,12,13,14,15]
, CP通信组为[0,8],[1,9],[2,10],[3,11],[4,12],[5,13],[6,14],[7,15]
。
1 | def initialize_model_parallel(...): |
- 在
megatron/core/transformer/custom_layers/transformer_engine.py
中TEDotProductAttention
会初始化相关CP通信组相关参数,TEDotProductAttention
继承自te.pytorch.DotProductAttention
,在前向中直接调用父类的的forward函数。
1 | class TEDotProductAttention(te.pytorch.DotProductAttention): |
- Transformer
Engine中
DotProductAttention
定义在transformer_engine/pytorch/attention.py
中,CP相关参数通过attn_kwargs进行传入。接下来会调用FlashAttention
的前反向。
1 | class DotProductAttention(torch.nn.Module): |
- 在FlashAttention中会通过函数
attn_forward_func_with_cp
进行调用,最终Attn前的all_gather通信是在AttnFuncWithCP
中通过send、recv通信来实现的, 执行完通信就执行对应的flash_attention算子的调用。
1 | def attn_forward_func_with_cp(...): |