Megatron-LM Source Code Series (5): FP16 Usage

Megatron-LM repository: Megatron-LM

1. Specifying the FP16 arguments

  • To train a model with fp16, pass --fp16 in the training launch arguments. The corresponding definition in megatron/arguments.py is:
group.add_argument('--fp16', action='store_true',
                   help='Run model in fp16 mode.')
  • The lm cross entropy is computed in fp32 by default. With --fp16 enabled, you can additionally pass --fp16-lm-cross-entropy to compute the lm cross entropy loss in fp16. The corresponding definition in megatron/arguments.py is:
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                   help='Move the cross entropy unreduced loss calculation'
                   'for lm head to fp16.')
  • Another argument related to fp16 in Megatron is args.fp32_residual_connection. When it is set, the residual connection is cast to fp32 before being computed; in the network this residual connection corresponds to the Embedding module.
if args.fp32_residual_connection:
    assert args.fp16 or args.bf16, \
        'residual connection in fp32 only supported when using fp16 or bf16.'
  • The validate_args function checks argument validity; the fp16-related part is as follows:
def validate_args(args, defaults={}):
    ......
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    ......
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'
    ......

If --fp16 is specified, args.fp16 is True and the parameter dtype args.params_dtype becomes torch.half.
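
Putting the two pieces together, here is a minimal standalone sketch (not the full Megatron code path) of how the store_true flag leads to params_dtype being torch.half:

import argparse
import torch

# Minimal sketch of the flag handling described above: --fp16 is a
# store_true flag, and validate_args-style logic switches params_dtype
# from torch.float to torch.half.
parser = argparse.ArgumentParser()
parser.add_argument('--fp16', action='store_true', help='Run model in fp16 mode.')
parser.add_argument('--bf16', action='store_true', help='Run model in bf16 mode.')
args = parser.parse_args(['--fp16'])

args.params_dtype = torch.float
if args.fp16:
    assert not args.bf16
    args.params_dtype = torch.half

print(args.params_dtype)  # torch.float16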

2. FP16 computation in the ParallelAttention module

2.1 Training

ParallelAttention contains sub-modules such as self.query_key_value, self.core_attention, and self.dense; the effect of fp16 on training is applied inside these sub-modules.

class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        ...
        self.query_key_value = tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            3 * projection_size,
            bias=args.add_bias_linear,
            gather_output=False,
            init_method=init_method,
            async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
            **_args_to_kwargs())
        ...
        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        ...
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            args.hidden_size,
            bias=args.add_bias_linear,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())

For the self.query_key_value and self.dense modules, the fp16 setting is passed in through **_args_to_kwargs() in the constructor arguments.

def _args_to_kwargs():
    args = get_args()
    common_kwargs = {
        "params_dtype": args.params_dtype,
        "use_cpu_initialization": args.use_cpu_initialization,
        "perform_initialization": args.perform_initialization,
        "gradient_accumulation_fusion": args.gradient_accumulation_fusion,
        "sequence_parallel_enabled": args.sequence_parallel,
    }
    return common_kwargs
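
Since these kwargs are simply unpacked into the linear layers' constructors, the net effect of --fp16 is that every parallel linear layer receives params_dtype=torch.half. A small sketch of the unpacking (make_linear is a hypothetical stand-in, not Megatron code):

import torch

# Hypothetical stand-in for a parallel linear constructor, used only to
# show how **_args_to_kwargs() forwards params_dtype.
def make_linear(input_size, output_size, params_dtype=torch.float32, **unused):
    return torch.empty(output_size, input_size, dtype=params_dtype)

common_kwargs = {"params_dtype": torch.half}   # what _args_to_kwargs() returns under --fp16 (other keys omitted)
w = make_linear(1024, 3 * 1024, **common_kwargs)
print(w.dtype)  # torch.float16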

For self.core_attention, the fp16 setting is stored in CoreAttention's __init__ as self.fp16 = args.fp16:

class CoreAttention(MegatronModule):

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16
        ...

2.2 Inference

fp16 also affects the inference path of the ParallelAttention module itself:

class ParallelAttention(MegatronModule):
    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        ...
        self.params_dtype = args.params_dtype
        ...

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        ...
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                ...
  • When fp16 is specified, ParallelAttention's __init__ sets the parameter dtype self.params_dtype to fp16.
  • When memory is pre-allocated, _allocate_memory uses torch.empty to create the large inference buffers, and their dtype is fp16 (see the sizing sketch after this list).
  • When the inference arguments inference_params are provided, the forward function calls _allocate_memory.
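
To get a feel for what the fp16 buffers save, here is a rough sizing sketch with made-up (hypothetical) shapes; the two buffers correspond to the key and value memory allocated per layer by _allocate_memory:

import torch

# Hypothetical sizes, not taken from any particular config.
max_seq_len, batch = 2048, 8
num_heads_per_partition, head_dim = 32, 128

elems = max_seq_len * batch * num_heads_per_partition * head_dim
bytes_fp16 = 2 * elems * torch.finfo(torch.half).bits // 8    # key + value buffers
bytes_fp32 = 2 * elems * torch.finfo(torch.float).bits // 8

print(bytes_fp16 / 2**20, "MiB in fp16")  # 256.0 MiB per layer
print(bytes_fp32 / 2**20, "MiB in fp32")  # 512.0 MiB per layer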

3. FP16 computation in the CoreAttention module

Once fp16 is set, the input to CoreAttention's forward is already of fp16 type; the fp16 flag stored in __init__ is mainly used by the FusedScaleMaskSoftmax module invoked during the computation to convert the dtype of its output.

class CoreAttention(MegatronModule):

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        ...
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)
        ...

When FusedScaleMaskSoftmax runs, it directly calls the fused operator forward_fused_softmax if the fp16 kernel supports the problem size; for unsupported sizes it falls back to forward_torch_softmax, where the output dtype is cast back according to self.input_in_float16.

class FusedScaleMaskSoftmax(nn.Module):
    ...
    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs
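
The dtype handling in forward_torch_softmax can be reproduced with a small standalone sketch (plain torch, no Megatron dependencies): the fp16 scores are upcast to fp32 for the softmax, and the probabilities are cast back to fp16 afterwards.

import torch

# Standalone sketch of the cast-up / cast-back pattern above.
scores = torch.randn(2, 4, 8, 8).half()          # [b, np, sq, sk] in fp16
input_in_fp16, softmax_in_fp32 = True, True

x = scores.float() if (input_in_fp16 and softmax_in_fp32) else scores
probs = torch.nn.Softmax(dim=-1)(x)              # softmax computed in fp32
if input_in_fp16 and softmax_in_fp32:
    probs = probs.half()                         # cast the result back to fp16

print(probs.dtype)  # torch.float16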

4. FP16 computation in the ColumnParallelLinear module

When ColumnParallelLinear is initialized, the Parameters it creates are allocated directly with params_dtype (i.e. fp16).

class ColumnParallelLinear(torch.nn.Module):
    def __init__(self, ...,
                 params_dtype=torch.float32,
                 ...,
                 ):
        ...
        self.weight = Parameter(torch.empty(
            self.output_size_per_partition, self.input_size,
            device=torch.cuda.current_device(), dtype=params_dtype))
        ...
        self.bias = Parameter(torch.empty(
            self.output_size_per_partition,
            device=torch.cuda.current_device(),
            dtype=params_dtype))
        ...
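
A minimal sketch of the same allocation on CPU (omitting the CUDA device argument, with hypothetical partition sizes): with params_dtype=torch.half, the weight and bias Parameters are created directly in fp16, two bytes per element.

import torch
from torch.nn import Parameter

params_dtype = torch.half                              # what --fp16 yields via validate_args
output_size_per_partition, input_size = 4096, 1024     # hypothetical partition sizes

weight = Parameter(torch.empty(output_size_per_partition, input_size, dtype=params_dtype))
bias = Parameter(torch.empty(output_size_per_partition, dtype=params_dtype))

print(weight.dtype, weight.element_size())             # torch.float16 2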

5. lm-cross-entropy computation

Taking the gpt2 model as an example, in the post_language_model_processing function in megatron/model/gpt_model.py: if fp16_lm_cross_entropy is specified, the cross entropy loss is computed directly on the fp16 output; otherwise the output is first cast to float32 before the loss is computed.

if fp16_lm_cross_entropy:
    assert output.dtype == torch.half
    loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
    loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
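
The same dtype handling can be sketched with plain torch (vocab_parallel_cross_entropy itself needs an initialized tensor-parallel group, so F.cross_entropy is used here only as a stand-in):

import torch
import torch.nn.functional as F

fp16_lm_cross_entropy = False                    # flip to True to keep the loss computation in fp16
output = torch.randn(8, 50257).half()            # stand-in for the lm-head logits
labels = torch.randint(0, 50257, (8,))

if fp16_lm_cross_entropy:
    assert output.dtype == torch.half
    loss = F.cross_entropy(output, labels)       # logits consumed directly in fp16
else:
    loss = F.cross_entropy(output.float(), labels)   # upcast to fp32 first (the default path)

print(loss.dtype)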