Skip to main content

Megatron Pipeline parallel走读

· 7 min read

最近在适配模型训练的工作中,需要修改 Megatron PP切分部分的代码,因此对 Megatron 这部分实现的代码进行了走读,做一下记录。

走读的代码分支为 core_r0.7.0,代码地址 https://github.com/NVIDIA/Megatron-LM/tree/core_r0.7.0, PP切分的逻辑主要在 megatron/core/pipeline_parallel 目录下。

PP

直接看核心逻辑,megatron/core/pipeline_parallel/pipeline_parallel.py 中的 forward_backward_pipelining_without_interleaving 函数, 这个函数是 pipeline parallel 的核心逻辑,旨在以流水线方式处理模型的前向和后向传递。

note

后续贴出的代码片段,会省略一些参数检查和断言,以及一些不重要的逻辑。

总体来看,这个函数包括三个主要流程:

  • warmup
  • 1F1B
  • cooldown

其中,warmup阶段只进行前向,1F1B阶段进行前向和反向传播,cooldown阶段只进行反向。

函数定义及其它处理

关注函数定义中的几个重要的变量。

  • data_iterator 参数可以是单个迭代器或迭代器列表,提供要在每个micro batch中处理的数据。
  • model 参数可以是单个 torch.nn.Module 或模块列表,表示当前进程的模型或其片段。
  • num_microbatches 参数指定要将输入数据分成的micro batch数量,而 micro_batch_size 定义了序列长度和每个micro batch的大小。
def forward_backward_pipelining_with_interleaving(
forward_step_func,
data_iterator: Union[Iterator, List[Iterator]],
model: Union[torch.nn.Module, List[torch.nn.Module]],
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
)

正式进入到函数的逻辑,跳过一些函数的断言和参数检查。首先看到的是关于num_warmup_microbatches的计算, num_warmup_microbatches 指的是在warmup阶段,每个进程需要计算的batch数,越靠前的PP stage,warmup阶段的batch数越多。

其中,num_microbatches = global_batch_size / micro_batch_size / data_parallel_size

1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
1
2
2
3
3
4
4
# Compute number of warmup microbatches.
num_warmup_microbatches = (
parallel_state.get_pipeline_model_parallel_world_size()
- parallel_state.get_pipeline_model_parallel_rank()
- 1
)
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches

随后需要获取两个shape,recv_tensor_shapessend_tensor_shapes,这两个变量的命名和实际用途可能有一些出入, 这两个shape实际分别表示前一个PP stage接受/输出的shape,和当前PP stage接受/输出的shape。

在接受前一个PP stage的前向输出或者向前一个PP stage发送梯度时,需要知道recv_tensor_shapes。 在向后一个PP stage发送前向输出或者接受后一个PP stage的梯度时,则需要知道send_tensor_shapes

rank = parallel_state.get_pipeline_model_parallel_rank()
recv_tensor_shapes = get_tensor_shapes(
rank=rank - 1,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
)
send_tensor_shapes = get_tensor_shapes(
rank=rank,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
)

warmup

1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
1
2
2
3
3
4
4

在warmup阶段,对每个warmup microbatch

  1. 调用recv_forward接受前一个PP stage的前向输出,(如果是第一个PP stage,forward_step会从dataloader中获取前向需要的数据)
  2. 然后执行forward_step,计算前向输出
  3. 调用send_forward将前向输出发送给下一个PP stage。
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = recv_forward(recv_tensor_shapes)
output_tensor, num_tokens = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
current_microbatch=i,
)
send_forward(output_tensor, send_tensor_shapes)
total_num_tokens += num_tokens.item()

1F1B

1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
1
2
2
3
3
4
4

在1F1B阶段,如果在warmup阶段没有完成所有的microbatch,需要先接受前一个PP stage的前向输出。 对每一个未完成的microbatch

  1. 调用forward_step计算前向输出
  2. 调用send_forward_recv_backward将前向输出发送给下一个PP stage,并接受后一个PP stage的梯度
  3. 如果是最后一个microbatch,需要调用enable_grad_sync,启用梯度同步
  4. 调用backward_step计算梯度
  5. 如果是最后一个microbatch,调用send_backward将梯度发送给前一个PP stage,否则调用send_backward_recv_forward将梯度发送给前一个PP stage,并接受前一个PP stage的前向输出。
# Before running 1F1B, need to receive first forward tensor.
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = recv_forward(recv_tensor_shapes, config)

# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
last_iteration = i == (num_microbatches_remaining - 1)

output_tensor, num_tokens = forward_step(
forward_step_func,
data_iterator,
model,
num_microbatches,
input_tensor,
current_microbatch=i + num_warmup_microbatches,
)
total_num_tokens += num_tokens.item()

output_tensor_grad = send_forward_recv_backward(
output_tensor, send_tensor_shapes, config
)

# Enable grad sync for the last microbatch in the batch if the full
# backward pass completes in the 1F1B stage.
if num_warmup_microbatches == 0 and last_iteration:
if config.grad_sync_func is None or rank == 0:
enable_grad_sync()

input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)

if last_iteration:
input_tensor = None
send_backward(input_tensor_grad, recv_tensor_shapes, config)
else:
input_tensor = send_backward_recv_forward(
input_tensor_grad, recv_tensor_shapes, config
)

cooldown

在cooldown阶段,恰好与warmup阶段相反,只进行反向传播,对每个warmup microbatch

  1. 调用recv_backward接受后一个PP stage的梯度
  2. 调用backward_step计算梯度
  3. 调用send_backward将梯度发送给前一个PP stage
  4. 如果是最后一个microbatch,调用enable_grad_sync,启用梯度同步
# Run cooldown backward passes.
for i in range(num_warmup_microbatches):

# Enable async grad reduction in the last backward pass
# Note: If grad sync function is provided, only enable
# async grad reduction in first pipeline stage. Other
# pipeline stages do grad reduction during pipeline
# bubble.
if i == num_warmup_microbatches - 1:
if config.grad_sync_func is None or rank == 0:
enable_grad_sync()

output_tensor_grad = recv_backward(send_tensor_shapes, config)

input_tensor_grad = backward_step(
input_tensor, output_tensor, output_tensor_grad, model_type, config
)

send_backward(input_tensor_grad, recv_tensor_shapes, config)
1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
1
2
2
3
3
4
4

VPP

待补充

Bubble计算

1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
2
3
4
1
1
2
2
3
3
4
4
Bubble比例: 0.75