group.add_argument('--recompute-activations', action='store_true',
                   help='recompute activation to allow for training '
                   'with larger models, sequences, and batch sizes.')
group.add_argument('--recompute-granularity', type=str, default=None,
                   choices=['full', 'selective'],
                   help='Checkpoint activations to allow for training '
                   'with larger models, sequences, and batch sizes. '
                   'It is supported at two granularities 1) full: '
                   'whole transformer layer is recomputed, '
                   '2) selective: core attention part of the transformer '
                   'layer is recomputed.')
group.add_argument('--distribute-saved-activations', action='store_true',
                   help='If set, distribute recomputed activations '
                   'across model parallel group.')
group.add_argument('--recompute-method', type=str, default=None,
                   choices=['uniform', 'block'],
                   help='1) uniform: uniformly divide the total number of '
                   'Transformer layers and recompute the input activation of '
                   'each divided chunk at specified granularity, '
                   '2) block: recompute the input activations of only a set '
                   'number of individual Transformer layers per pipeline stage '
                   'and do the rest without any recomputing at specified '
                   'granularity, '
                   'default) do not apply activations recompute to any layers')
group.add_argument('--recompute-num-layers', type=int, default=1,
                   help='1) uniform: the number of Transformer layers in each '
                   'uniformly divided recompute unit, '
                   '2) block: the number of individual Transformer layers '
                   'to recompute within each pipeline stage.')
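To see how these flags combine in practice, the standalone sketch below (my own snippet, not taken from Megatron-LM) parses just the recompute-related flags and prints the resulting configuration; the flag names, choices, and defaults mirror the argument group above.

# Hypothetical standalone sketch: parse only the recompute flags shown above.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='activation recompute')
group.add_argument('--recompute-granularity', type=str, default=None,
                   choices=['full', 'selective'])
group.add_argument('--recompute-method', type=str, default=None,
                   choices=['uniform', 'block'])
group.add_argument('--recompute-num-layers', type=int, default=1)

# Example: full recomputation, checkpointing the input of every 2-layer chunk.
args = parser.parse_args(['--recompute-granularity', 'full',
                          '--recompute-method', 'uniform',
                          '--recompute-num-layers', '2'])
print(args.recompute_granularity, args.recompute_method, args.recompute_num_layers)
# full uniform 2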
if self.recompute_method == 'uniform':
    # Uniformly divide the total number of Transformer layers and
    # checkpoint the input activation of each divided chunk.
    # A method to further reduce memory usage by reducing the number
    # of checkpoints.
    l = 0
    while l < self.num_layers:
        if self.transformer_impl == 'transformer_engine':
            ...
        else:
            hidden_states = tensor_parallel.checkpoint(
                custom(l, l + self.recompute_num_layers),
                self.distribute_saved_activations,
                hidden_states, attention_mask,
                encoder_output, enc_dec_attn_mask,
                None, None, None, None, rotary_pos_emb)
        l += self.recompute_num_layers

elif self.recompute_method == 'block':
    # Checkpoint the input activation of only a set number of individual
    # Transformer layers and skip the rest.
    # A method to fully use the device memory by removing redundant
    # re-computation.
    for l in range(self.num_layers):
        if l < self.recompute_num_layers:
            if self.transformer_impl == 'transformer_engine':
                ...
            else:
                hidden_states = tensor_parallel.checkpoint(
                    custom(l, l + 1),
                    self.distribute_saved_activations,
                    hidden_states, attention_mask,
                    encoder_output, enc_dec_attn_mask,
                    None, None, None, None, rotary_pos_emb)
        else:
            if self.transformer_impl == 'transformer_engine':
                ...
            else:
                hidden_states = custom(l, l + 1)(
                    hidden_states, attention_mask,
                    encoder_output, enc_dec_attn_mask,
                    None, None, None, None, rotary_pos_emb)
...
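To make the difference between the two methods concrete, here is a toy sketch (not Megatron code) that only prints which layer ranges end up checkpointed, assuming for illustration num_layers=8 and recompute_num_layers=2.

# Toy sketch of the layer ranges selected by each recompute method.
num_layers = 8
recompute_num_layers = 2

def uniform_chunks(num_layers, chunk):
    # uniform: every chunk of `chunk` consecutive layers is checkpointed
    # as a single unit, so only each chunk's input is kept in memory.
    return [(l, l + chunk) for l in range(0, num_layers, chunk)]

def block_chunks(num_layers, n):
    # block: only the first `n` layers of the stage are checkpointed
    # (one layer per unit); the remaining layers keep all activations.
    checkpointed = [(l, l + 1) for l in range(n)]
    not_checkpointed = [(l, l + 1) for l in range(n, num_layers)]
    return checkpointed, not_checkpointed

print(uniform_chunks(num_layers, recompute_num_layers))
# [(0, 2), (2, 4), (4, 6), (6, 8)]  -> 4 checkpointed units
ck, plain = block_chunks(num_layers, recompute_num_layers)
print(ck)     # [(0, 1), (1, 2)]              -> layers 0 and 1 are recomputed
print(plain)  # [(2, 3), (3, 4), ..., (7, 8)] -> layers 2..7 are not recomputed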
For the selective recompute granularity, choosing a recompute method is not currently supported.
if args.recompute_granularity == 'selective':
    assert args.recompute_method is None, \
        'recompute method is not yet supported for ' \
        'selective recomputing granularity'
# Activation recomputing.
if args.distribute_saved_activations:
    assert args.tensor_model_parallel_size > 1, 'can distribute ' \
        'recomputed activations only across tensor model ' \
        'parallel groups'
    assert args.recompute_granularity == 'full', \
        'distributed recompute activations is only ' \
        'applicable to full recompute granularity'
    assert args.recompute_method is not None, \
        'for distributed recompute activations to work you ' \
        'need to use a recompute method '
    assert TORCH_MAJOR >= 1 and TORCH_MINOR >= 10, \
        'distributed recompute activations are supported for pytorch ' \
        'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
        'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)
with torch.no_grad():
    outputs = run_function(*args)
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if distribute_saved_activations:
    ctx.input_0_shape = args[0].data.shape
    safely_set_viewless_tensor_data(
        args[0],
        split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
...
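The call to split_tensor_into_1d_equal_chunks is what implements --distribute-saved-activations: the saved input is flattened and each tensor-parallel rank keeps only its own 1/world_size slice, and the full tensor is gathered back before recomputation in backward. The toy sketch below is my own single-process approximation of that splitting idea, not Megatron's implementation.

import torch

# Toy approximation of split_tensor_into_1d_equal_chunks: each rank keeps
# only its slice of the flattened activation (assumes numel is evenly
# divisible by world_size for simplicity).
def split_into_1d_equal_chunks(tensor, rank, world_size):
    flat = tensor.view(-1)
    chunk = flat.numel() // world_size
    return flat[rank * chunk:(rank + 1) * chunk].clone()

hidden = torch.randn(4, 2, 8)      # e.g. [sequence, batch, hidden]
world_size = 2                     # pretend tensor-parallel degree
saved = [split_into_1d_equal_chunks(hidden, r, world_size)
         for r in range(world_size)]
print([s.numel() for s in saved])  # [32, 32] -- each rank stores half

# Before the backward recomputation, the chunks are gathered back into
# the original shape (the matching gather step in the real code).
restored = torch.cat(saved).view(hidden.shape)
assert torch.equal(restored, hidden)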