Megatron-LM代码仓:Megatron-LM
1. FP16参数指定
训练模型要使用fp16时,训练启动参数中指定--fp16
,
对应megatron/arguments.py
中的定义如下:
1 2 group.add_argument('--fp16' , action='store_true' , help ='Run model in fp16 mode.' )
在计算lm-cross-entropy
时默认是使用fp32来计算的,在开启--fp16
选项的前提下可以通过指定--fp16-lm-cross-entropy
来使用fp16计算lm-loss-entropy
,对应megatron/arguments.py
中的定义如下:
1 2 3 group.add_argument('--fp16-lm-cross-entropy' , action='store_true' , help ='Move the cross entropy unreduced loss calculation' 'for lm head to fp16.' )
在megatron中跟fp16还有关系的一个参数是args.fp32_residual_connection
,这里设置了的话会在计算残差连接的时候转为fp32再进行计算,这里残差连接在网络中对应是Embedding模块。
1 2 3 if args.fp32_residual_connection: assert args.fp16 or args.bf16, \ 'residual connection in fp32 only supported when using fp16 or bf16.'
validate_args
函数用于check参数有效性,fp16相关实现如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 def validate_args (args, defaults={} ): ...... args.params_dtype = torch.float if args.fp16: assert not args.bf16 args.params_dtype = torch.half ...... if args.fp16_lm_cross_entropy: assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.' if args.fp32_residual_connection: assert args.fp16 or args.bf16, \ 'residual connection in fp32 only supported when using fp16 or bf16.' ......
如果指定了fp16,这里的args.fp16
为True,对应的args.params_dtype
参数类型为torch.half
。
2.
ParallelAttention模块中fp16计算
2.1 训练部分
ParallelAttention中有self.query_key_value
、self.core_attention
和self.dense
等子模块,fp16对训练的影响会应用在子模块中。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 class ParallelAttention(MegatronModule): "" "Parallel self-attention layer abstract class. Self-attention layer takes input with size [s, b, h] and returns output of the same size. " "" def __init__(self, init_method, output_layer_init_method, layer_number, attention_type =AttnType.self_attn, attn_mask_type =AttnMaskType.padding): .. . self.query_key_value = tensor_parallel.ColumnParallelLinear( args.hidden_size, 3 * projection_size, bias =args.add_bias_linear, gather_output =False , init_method =init_method, async_tensor_model_parallel_allreduce =args.async_tensor_model_parallel_allreduce, **_args_to_kwargs()) .. . self.core_attention = CoreAttention(self.layer_number, self.attn_mask_type) .. . self.dense = tensor_parallel.RowParallelLinear( projection_size, args.hidden_size, bias =args.add_bias_linear, input_is_parallel =True , init_method =output_layer_init_method, skip_bias_add =True , **_args_to_kwargs())
对于self.query_key_value
和self.dense
模块,fp16的设置能过参数中的**_args_to_kwargs()
进行传递。
1 2 3 4 5 6 7 8 9 10 def _args_to_kwargs (): args = get_args() common_kwargs = { "params_dtype" : args.params_dtype, "use_cpu_initialization" : args.use_cpu_initialization, "perform_initialization" : args.perform_initialization, "gradient_accumulation_fusion" : args.gradient_accumulation_fusion, "sequence_parallel_enabled" : args.sequence_parallel, } return common_kwargs
对于self.core_attention
部分,fp16的设置是在CoreAttention
的__init__
中self.fp16 = args.fp16
。
1 2 3 4 5 6 7 8 9 class CoreAttention (MegatronModule ): def __init__ (self , layer_number, attn_mask_type=AttnMaskType .padding ): super (CoreAttention , self ).__init__() args = get_args() self .fp16 = args.fp16 self .bf16 = args.bf16 ...
2.2 推理部分
在ParallelAttention
模块本身中fp16会影响推理部分
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 class ParallelAttention (MegatronModule ): def __init__ (self, init_method, output_layer_init_method, layer_number, attention_type=AttnType.self_attn, attn_mask_type=AttnMaskType.padding ): ... self.params_dtype = args.params_dtype ... def _allocate_memory (self, inference_max_sequence_len, batch_size ): return torch.empty( inference_max_sequence_len, batch_size, self.num_attention_heads_per_partition, self.hidden_size_per_attention_head, dtype=self.params_dtype, device=torch.cuda.current_device()) def forward (self, hidden_states, attention_mask, encoder_output=None , inference_params=None , rotary_pos_emb=None ): ... if inference_params: if self.layer_number not in inference_params.key_value_memory_dict: inf_max_seq_len = inference_params.max_sequence_len inf_max_batch_size = inference_params.max_batch_size inference_key_memory = self._allocate_memory( inf_max_seq_len, inf_max_batch_size) inference_value_memory = self._allocate_memory( inf_max_seq_len, inf_max_batch_size) ...
当指定了fp16以后,在ParallelAttention
模型__init__
初始化时会设置参数类型self.params_dtype
为fp16
在提前分配memory时_allocate_memory
中会用torch.empty
创建用于推理的大buffer,类型是fp16
在指定推理参数inference_params
时,forward函数中会调用_allocate_memory
3. CoreAttention模块中fp16计算
当设了fp16以后,在CoreAttention
的forward计算的input就是fp16类型,在init中设置fp16
flag主要是用于计算中用到的FusedScaleMaskSoftmax
模块的输出结果类型转换。
1 2 3 4 5 6 7 8 9 10 11 12 13 class CoreAttention (MegatronModule ): def __init__ (self, layer_number, attn_mask_type=AttnMaskType.padding ): ... self.scale_mask_softmax = FusedScaleMaskSoftmax( self.fp16, self.bf16, self.attn_mask_type, args.masked_softmax_fusion, attention_mask_func, self.attention_softmax_in_fp32, coeff) ...
当FusedScaleMaskSoftmax
执行时,kernel支持fp16时会直接调用fusion算子forward_fused_softmax
;对于不支持的规模时,会调用forward_torch_softmax
进行模拟,输出的类型就根据self.input_in_float16
来进行cast转换。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 class FusedScaleMaskSoftmax (nn.Module): ... def forward (self, input , mask ): assert input .dim() == 4 if self.is_kernel_available(mask, *input .size()): return self.forward_fused_softmax(input , mask) else : return self.forward_torch_softmax(input , mask) def forward_torch_softmax (self, input , mask ): if self.input_in_float16 and self.softmax_in_fp32: input = input .float () if self.scale is not None : input = input * self.scale mask_output = self.mask_func(input , mask) if mask is not None else input probs = torch.nn.Softmax(dim=-1 )(mask_output) if self.input_in_float16 and self.softmax_in_fp32: if self.input_in_fp16: probs = probs.half() else : probs = probs.bfloat16() return probs
4.
ColumnParallelLinear模块中fp16计算
在ColumnParallelLinear
初始化时创建Parameter中的类型直接按params_dtype(即fp16)
来设。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 class ColumnParallelLinear (torch.nn.Module):def __init__ (self, ..., params_dtype=torch.float32, ..., ): ... self.weight = Parameter(torch.empty( self.output_size_per_partition, self.input_size, device=torch.cuda.current_device(), dtype=params_dtype)) ... self.bias = Parameter(torch.empty( self.output_size_per_partition, device=torch.cuda.current_device(), dtype=params_dtype)) ...
5. lm-cross-entropy计算
以gpt2模型为例,在megatron/model/gpt_model.py
文件中的post_language_model_processing
函数,
如果指定了fp16_lm_cross_entropy
,那么在计算cross entropy
时会把output
先转为float32
再进行计算loss。
1 2 3 4 5 if fp16_lm_cross_entropy: assert output.dtype == torch.half loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels) else : loss = tensor_parallel.vocab_parallel_cross_entropy(output.float (), labels)