|
import torch |
|
import torch.distributed as dist |
|
|
|
|
|
|
|
|
|
|
|
def _all_to_all(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    scatter_dim: int,
    gather_dim: int,
):
    """Exchange slices of ``input_`` between the ranks of ``group``.

    ``input_`` is cut into ``world_size`` pieces along ``scatter_dim``, the
    pieces are handed to ``dist.all_to_all``, and the received pieces are
    concatenated along ``gather_dim``.

    Args:
        input_: Local tensor to redistribute.
        world_size: Number of pieces to cut ``input_`` into, and number of
            receive buffers to allocate.
        group: Process group forwarded to ``dist.all_to_all``.
        scatter_dim: Dimension along which ``input_`` is cut.
        gather_dim: Dimension along which the received pieces are joined.

    Returns:
        A contiguous tensor made of the received pieces concatenated along
        ``gather_dim``.
    """
    # Slices are views of ``input_``; the collective is given dense copies.
    send_chunks = [
        chunk.contiguous()
        for chunk in input_.tensor_split(world_size, dim=scatter_dim)
    ]

    # Every receive buffer is shaped like the first outgoing piece.
    # NOTE(review): presumably this requires ``scatter_dim`` to divide evenly
    # by ``world_size`` so all pieces share one shape -- confirm with callers.
    template = send_chunks[0]
    recv_chunks = [torch.empty_like(template) for _ in range(world_size)]

    dist.all_to_all(recv_chunks, send_chunks, group=group)

    joined = torch.cat(recv_chunks, dim=gather_dim)
    return joined.contiguous()
|
|
|
|
|
class _AllToAll(torch.autograd.Function):
    """Differentiable all-to-all exchange.

    The forward pass calls ``_all_to_all`` with the given scatter/gather
    dimensions; the backward pass calls it again on the incoming gradient
    with the two dimensions swapped.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        """Run the exchange and remember its parameters for ``backward``."""
        world_size = dist.get_world_size(process_group)

        # Stash everything backward() needs to run the reverse exchange.
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = world_size

        return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Route the gradient back with scatter and gather dims exchanged."""
        grad_input = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,  # scatter along what forward gathered on
            ctx.scatter_dim,  # gather along what forward scattered on
        )
        # One slot per forward() argument; only ``input_`` receives a gradient.
        return grad_input, None, None, None
|
|
|
|
|
def all_to_all(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Apply the autograd-aware all-to-all exchange to ``input_``.

    Args:
        input_: Local tensor to redistribute.
        process_group: Communication group handed to ``_AllToAll``.
        scatter_dim: Dimension to scatter along (default ``2``).
        gather_dim: Dimension to gather along (default ``1``).

    Returns:
        The result of ``_AllToAll.apply`` for the given arguments.
    """
    exchanged = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
    return exchanged
|
|
|
|
|
def _gather(
    input_: torch.Tensor,
    world_size: int,
    group: dist.ProcessGroup,
    gather_dim: int,
):
    """Collect ``input_`` from every rank of ``group`` into a list.

    The previous body read ``gather_list`` before assigning it (an
    ``UnboundLocalError`` on every call) and passed a ``gather_dim`` keyword
    that ``dist.gather`` does not accept. The buffers are now always
    allocated here and filled with ``dist.all_gather``, so every rank gets
    the full list back.

    NOTE(review): a later ``def _gather(input_, pg, dim=-1)`` in this module
    rebinds the name, so this four-argument version is shadowed at import
    time -- confirm whether it is still needed.

    Args:
        input_: Local tensor to contribute.
        world_size: Number of ranks in ``group``; one buffer is allocated
            per rank.
        group: Process group forwarded to ``dist.all_gather``.
        gather_dim: Unused by the collective; kept so the signature stays
            unchanged. The tensors are returned as a list, not concatenated.

    Returns:
        A list of ``world_size`` tensors shaped like ``input_``, filled by
        ``dist.all_gather``.
    """
    # Make the input dense first so the buffers created from it are dense too.
    input_ = input_.contiguous()
    gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(gather_list, input_, group=group)
    return gather_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _split(input_, pg: dist.ProcessGroup, dim=-1):
    """Return the calling rank's equal share of ``input_`` along ``dim``.

    Args:
        input_: Tensor to slice.
        pg: Process group whose size and rank select the slice.
        dim: Dimension to cut along (defaults to the last one).

    Returns:
        ``input_`` itself when the group has a single rank; otherwise a
        contiguous copy of the ``rank``-th of ``world_size`` equal slices.

    Raises:
        AssertionError: If the size of ``dim`` is not a multiple of the
            group's world size.
    """
    world_size = dist.get_world_size(pg)
    my_rank = dist.get_rank(pg)

    # A single-rank group owns the whole tensor; hand it back untouched.
    if world_size == 1:
        return input_

    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, (
        f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), "
        f"cannot split tensor evenly"
    )

    shard_len = dim_size // world_size
    shards = input_.split(shard_len, dim=dim)
    # The selected shard is a view of ``input_``; return a dense copy.
    return shards[my_rank].contiguous()
|
|
|
|
|
def _gather(input_, pg: dist.ProcessGroup, dim=-1):
    """All-gather ``input_`` across ``pg`` and concatenate along ``dim``.

    This definition rebinds the module-level name ``_gather``, replacing the
    earlier four-argument function of the same name.

    Args:
        input_: Local tensor to contribute; it is made contiguous first.
        pg: Process group to gather over.
        dim: Dimension along which the gathered tensors are concatenated
            (defaults to the last one).

    Returns:
        The contiguous ``input_`` when the group has a single rank; otherwise
        a contiguous tensor holding every rank's contribution concatenated
        along ``dim``.

    Raises:
        AssertionError: If the group has more than one rank and ``input_`` is
            not on a CUDA device.
    """
    input_ = input_.contiguous()
    world_size = dist.get_world_size(pg)

    # Nothing to exchange in a single-rank group.
    if world_size == 1:
        return input_

    # Check the device before allocating any receive buffers.
    assert input_.device.type == "cuda"

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(tensor_list, input_, group=pg)

    return torch.cat(tensor_list, dim=dim).contiguous()
|
|
|
|
|
class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from model parallel region and concatenate.

    Forward gathers ``input_`` across the process group along ``dim``;
    backward optionally rescales the gradient and splits it back so each
    rank keeps its own slice.

    Args:
        input_: input matrix.
        process_group: parallel mode.
        dim: dimension
        grad_scale: ``"up"`` multiplies the gradient by the group's world
            size in backward, ``"down"`` divides it; any other value leaves
            the gradient unscaled.
    """

    @staticmethod
    def symbolic(graph, input_, process_group=None, dim=-1, grad_scale=None):
        # Mirrors forward()'s arguments. The old ``_gather(input_)`` call
        # omitted the required process-group argument and raised TypeError.
        # The added parameters have defaults, so one-argument calls still work.
        # NOTE(review): this calls the eager helper instead of emitting graph
        # ops -- confirm whether ONNX export is actually expected to work.
        return _gather(input_, process_group, dim)

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale):
        # Remember what backward() needs to undo the gather.
        ctx.mode = process_group
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        return _gather(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale == "up":
            grad_output = grad_output * dist.get_world_size(ctx.mode)
        elif ctx.grad_scale == "down":
            grad_output = grad_output / dist.get_world_size(ctx.mode)

        # One slot per forward() argument; only ``input_`` receives a gradient.
        return _split(grad_output, ctx.mode, ctx.dim), None, None, None
|
|
|
|
|
class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split the input and keep only the chunk corresponding to the rank.

    Forward keeps this rank's slice of ``input_`` along ``dim``; backward
    optionally rescales the gradient and gathers it across the process
    group along the same dimension.

    Args:
        input_: input matrix.
        process_group: parallel mode.
        dim: dimension
        grad_scale: ``"up"`` multiplies the gradient by the group's world
            size in backward, ``"down"`` divides it; any other value leaves
            the gradient unscaled.
    """

    @staticmethod
    def symbolic(graph, input_, process_group=None, dim=-1, grad_scale=None):
        # Mirrors forward()'s arguments. The old ``_split(input_)`` call
        # omitted the required process-group argument and raised TypeError.
        # The added parameters have defaults, so one-argument calls still work.
        # NOTE(review): this calls the eager helper instead of emitting graph
        # ops -- confirm whether ONNX export is actually expected to work.
        return _split(input_, process_group, dim)

    @staticmethod
    def forward(ctx, input_, process_group, dim, grad_scale):
        # Remember what backward() needs to undo the split.
        ctx.mode = process_group
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        return _split(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale == "up":
            grad_output = grad_output * dist.get_world_size(ctx.mode)
        elif ctx.grad_scale == "down":
            grad_output = grad_output / dist.get_world_size(ctx.mode)

        # One slot per forward() argument; only ``input_`` receives a gradient.
        return _gather(grad_output, ctx.mode, ctx.dim), None, None, None
|
|
|
|
|
def split_forward_gather_backward(input_, process_group, dim, grad_scale=1.0):
    """Split ``input_`` in forward and gather its gradient in backward.

    Thin wrapper around ``_SplitForwardGatherBackward.apply``.

    Args:
        input_: Tensor to split.
        process_group: Process group to split over.
        dim: Dimension to split (and, in backward, gather) along.
        grad_scale: Passed through to the autograd function, which scales
            the gradient only for ``"up"`` or ``"down"``; the default ``1.0``
            matches neither, so the gradient is left unscaled.

    Returns:
        The result of ``_SplitForwardGatherBackward.apply``.
    """
    apply_fn = _SplitForwardGatherBackward.apply
    return apply_fn(input_, process_group, dim, grad_scale)
|
|
|
|
|
def gather_forward_split_backward(input_, process_group, dim, grad_scale=None):
    """Gather ``input_`` in forward and split its gradient in backward.

    Thin wrapper around ``_GatherForwardSplitBackward.apply``.

    Args:
        input_: Local tensor to gather.
        process_group: Process group to gather over.
        dim: Dimension to gather (and, in backward, split) along.
        grad_scale: Passed through to the autograd function, which scales
            the gradient only for ``"up"`` or ``"down"``; the default
            ``None`` leaves the gradient unscaled.

    Returns:
        The result of ``_GatherForwardSplitBackward.apply``.
    """
    apply_fn = _GatherForwardSplitBackward.apply
    return apply_fn(input_, process_group, dim, grad_scale)
|
|