Megatron-LM code repository: Megatron-LM

1. Specifying the FP16 arguments
To train a model in fp16, pass --fp16 in the training launch arguments. The corresponding definition in megatron/arguments.py is:
 
```python
group.add_argument('--fp16', action='store_true',
                   help='Run model in fp16 mode.')
```
The lm cross entropy is computed in fp32 by default. With --fp16 enabled, you can additionally pass --fp16-lm-cross-entropy to compute the unreduced lm cross-entropy loss in fp16. The corresponding definition in megatron/arguments.py is:
 
```python
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                   help='Move the cross entropy unreduced loss calculation '
                   'for lm head to fp16.')
```
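As a quick, standalone illustration of how the two flags combine, here is a minimal argparse sketch (not Megatron's actual parser setup; the argument group name is made up) that reproduces the dependency validate_args later enforces between them:

```python
# Minimal standalone sketch: --fp16-lm-cross-entropy is only meaningful
# when --fp16 is also set, which validate_args enforces with an assert.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='mixed precision')  # illustrative group name
group.add_argument('--fp16', action='store_true',
                   help='Run model in fp16 mode.')
group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                   help='Move the cross entropy unreduced loss calculation '
                   'for lm head to fp16.')

args = parser.parse_args(['--fp16', '--fp16-lm-cross-entropy'])
if args.fp16_lm_cross_entropy:
    assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
print(args.fp16, args.fp16_lm_cross_entropy)  # True True
```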
Another fp16-related argument in Megatron is args.fp32_residual_connection. When it is set, the residual connection is cast to fp32 before the computation; in the network this corresponds to the Embedding module.
 
```python
if args.fp32_residual_connection:
    assert args.fp16 or args.bf16, \
        'residual connection in fp32 only supported when using fp16 or bf16.'
```
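For context, the cast itself happens in the Embedding module's forward. Below is a simplified standalone sketch of that behavior; ToyEmbedding and its attribute names are illustrative stand-ins, not the real Megatron classes:

```python
# Sketch of the fp32 residual connection cast (illustrative class, not the
# Megatron source): with --fp32-residual-connection the embedding output that
# feeds the residual stream is promoted to fp32 even though the embedding
# parameters themselves are stored in fp16.
import torch

class ToyEmbedding(torch.nn.Module):
    def __init__(self, vocab_size, hidden_size, params_dtype, fp32_residual_connection):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size, dtype=params_dtype)
        self.fp32_residual_connection = fp32_residual_connection

    def forward(self, input_ids):
        embeddings = self.word_embeddings(input_ids)  # fp16 when params_dtype is torch.half
        if self.fp32_residual_connection:
            embeddings = embeddings.float()           # cast the residual stream to fp32
        return embeddings

emb = ToyEmbedding(1000, 64, torch.half, fp32_residual_connection=True)
out = emb(torch.randint(0, 1000, (4, 8)))
print(out.dtype)  # torch.float32
```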
The validate_args function checks argument validity; its fp16-related logic is as follows:
```python
def validate_args(args, defaults={}):
    ......
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    ......

    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'

    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'
    ......
```
When --fp16 is specified, args.fp16 is True and args.params_dtype is set to torch.half.
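To make the outcome concrete, here is a minimal sketch of just the dtype-selection branch shown above (the bf16 handling is elided, as in the excerpt):

```python
# Simplified dtype selection mirroring the excerpt: fp32 by default,
# torch.half when --fp16 is set, and fp16/bf16 are mutually exclusive.
import torch

def select_params_dtype(fp16: bool, bf16: bool = False) -> torch.dtype:
    params_dtype = torch.float
    if fp16:
        assert not bf16
        params_dtype = torch.half
    return params_dtype

print(select_params_dtype(fp16=True))   # torch.float16
print(select_params_dtype(fp16=False))  # torch.float32
```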
2. FP16 computation in the ParallelAttention module
2.1 Training
ParallelAttention contains sub-modules such as self.query_key_value, self.core_attention, and self.dense; the effect of fp16 on training is applied inside these sub-modules.
```python
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        ...
        self.query_key_value = tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            3 * projection_size,
            bias=args.add_bias_linear,
            gather_output=False,
            init_method=init_method,
            async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
            **_args_to_kwargs())
        ...
        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        ...
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            args.hidden_size,
            bias=args.add_bias_linear,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            **_args_to_kwargs())
```
For the self.query_key_value and self.dense modules, the fp16 setting is passed in through the **_args_to_kwargs() expansion in their constructor arguments.
```python
def _args_to_kwargs():
    args = get_args()
    common_kwargs = {
        "params_dtype": args.params_dtype,
        "use_cpu_initialization": args.use_cpu_initialization,
        "perform_initialization": args.perform_initialization,
        "gradient_accumulation_fusion": args.gradient_accumulation_fusion,
        "sequence_parallel_enabled": args.sequence_parallel,
    }
    return common_kwargs
```
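A standalone sketch of what the **_args_to_kwargs() expansion amounts to for these linear layers; ToyLinear is an illustrative stand-in, not the real ColumnParallelLinear signature:

```python
# Illustration of how a shared kwargs dict carries params_dtype into a layer
# constructor; under --fp16 the dict would hold torch.half.
import torch

common_kwargs = {"params_dtype": torch.half}  # what _args_to_kwargs() would return under --fp16

class ToyLinear(torch.nn.Module):
    def __init__(self, input_size, output_size, params_dtype=torch.float32, **unused):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.empty(output_size, input_size, dtype=params_dtype))

layer = ToyLinear(1024, 3 * 1024, **common_kwargs)
print(layer.weight.dtype)  # torch.float16
```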
For the self.core_attention part, the fp16 setting is picked up in CoreAttention's __init__ via self.fp16 = args.fp16.
```python
class CoreAttention(MegatronModule):

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16
    ...
```
2.2 Inference
Within the ParallelAttention module itself, fp16 affects the inference path:
```python
class ParallelAttention(MegatronModule):

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        ...
        self.params_dtype = args.params_dtype
        ...

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        ...
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
        ...
```
When fp16 is specified, ParallelAttention's __init__ sets the parameter dtype self.params_dtype to fp16. When memory is pre-allocated, _allocate_memory uses torch.empty to create the large inference buffers in fp16, and forward calls _allocate_memory whenever inference_params is passed in.
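To see why the buffer dtype matters, here is a back-of-the-envelope sketch of the per-layer key/value cache size allocated by _allocate_memory; the model numbers are purely illustrative:

```python
# Rough size of one key (or value) inference buffer as allocated above:
# max_seq_len * batch_size * heads_per_partition * head_dim elements.
# Illustrative numbers only; fp16 halves the footprint relative to fp32.
max_seq_len = 2048
batch_size = 8
heads_per_partition = 32
head_dim = 128

elements = max_seq_len * batch_size * heads_per_partition * head_dim
print(f"fp16: {elements * 2 / 2**30:.2f} GiB per buffer per layer")  # 2 bytes per element
print(f"fp32: {elements * 4 / 2**30:.2f} GiB per buffer per layer")  # 4 bytes per element
```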
 
3. FP16 computation in the CoreAttention module
Once fp16 is set, the input to CoreAttention's forward is already fp16; the fp16 flag stored in __init__ is mainly used by the FusedScaleMaskSoftmax module to decide how to cast its output.
```python
class CoreAttention(MegatronModule):

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        ...
        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)
    ...
```
When FusedScaleMaskSoftmax executes and the fused kernel supports the fp16 input at the given problem size, it calls the fused operator forward_fused_softmax directly; for unsupported sizes it falls back to forward_torch_softmax, which emulates the computation and casts the output according to self.input_in_float16.
```python
class FusedScaleMaskSoftmax(nn.Module):
    ...
    def forward(self, input, mask):
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs
```
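The dtype round-trip in forward_torch_softmax can be reproduced with plain PyTorch; here is a minimal sketch without masking or scaling, with illustrative names:

```python
# Minimal reproduction of the dtype round-trip in forward_torch_softmax:
# fp16 input -> softmax computed in fp32 for numerical stability -> cast
# back to fp16 so downstream layers keep working in half precision.
import torch

def torch_softmax_like(scores: torch.Tensor, softmax_in_fp32: bool = True) -> torch.Tensor:
    input_in_fp16 = scores.dtype == torch.float16
    if input_in_fp16 and softmax_in_fp32:
        scores = scores.float()
    probs = torch.nn.Softmax(dim=-1)(scores)
    if input_in_fp16 and softmax_in_fp32:
        probs = probs.half()
    return probs

scores = torch.randn(2, 4, 16, 16, dtype=torch.float16)  # [b, heads, sq, sk]
print(torch_softmax_like(scores).dtype)  # torch.float16
```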
4. FP16 computation in the ColumnParallelLinear module
When ColumnParallelLinear is initialized, its Parameters are created directly with dtype params_dtype (i.e. fp16 under --fp16).
```python
class ColumnParallelLinear(torch.nn.Module):

    def __init__(self, ...,
                 params_dtype=torch.float32,
                 ...,
                 ):
        ...
        self.weight = Parameter(torch.empty(
            self.output_size_per_partition, self.input_size,
            device=torch.cuda.current_device(), dtype=params_dtype))
        ...
        self.bias = Parameter(torch.empty(
            self.output_size_per_partition,
            device=torch.cuda.current_device(),
            dtype=params_dtype))
        ...
```
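A quick standalone check of what this means for parameter storage (the partition sizes below are chosen arbitrarily for illustration):

```python
# A weight created with torch.empty(..., dtype=torch.half) uses 2 bytes per
# element instead of 4, halving the parameter memory of the partition.
import torch
from torch.nn.parameter import Parameter

output_size_per_partition, input_size = 4096, 1024
w_fp16 = Parameter(torch.empty(output_size_per_partition, input_size, dtype=torch.half))
w_fp32 = Parameter(torch.empty(output_size_per_partition, input_size, dtype=torch.float32))
print(w_fp16.element_size() * w_fp16.nelement())  # 8388608 bytes (8 MiB)
print(w_fp32.element_size() * w_fp32.nelement())  # 16777216 bytes (16 MiB)
```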
5. lm-cross-entropy computation
Taking the GPT-2 model as an example, in the post_language_model_processing function in megatron/model/gpt_model.py: when fp16_lm_cross_entropy is specified, the cross-entropy loss is computed directly on the fp16 output; otherwise the output is first cast to float32 before the loss is computed.
```python
if fp16_lm_cross_entropy:
    assert output.dtype == torch.half
    loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
    loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
```
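As a standalone sketch of the same branching, here is a version that uses torch.nn.functional.cross_entropy in place of tensor_parallel.vocab_parallel_cross_entropy (illustrative only, not Megatron's actual loss path):

```python
# Illustrative stand-in for the branch above: by default the fp16 logits are
# upcast to fp32 before the loss; with fp16_lm_cross_entropy the unreduced
# loss is computed directly on the fp16 output.
import torch
import torch.nn.functional as F

output = torch.randn(8, 32000, dtype=torch.half)  # fp16 logits [tokens, vocab]
labels = torch.randint(0, 32000, (8,))
fp16_lm_cross_entropy = False

if fp16_lm_cross_entropy:
    assert output.dtype == torch.half
    loss = F.cross_entropy(output, labels, reduction='none')
else:
    loss = F.cross_entropy(output.float(), labels, reduction='none')
print(loss.dtype)  # torch.float32 here; torch.float16 when the flag is set
```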