import random
from typing import Optional

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import ProcessGroupMesh

DP_AXIS, SP_AXIS = 0, 1


class ZeroSeqParallelPlugin(LowLevelZeroPlugin):
    """A ``LowLevelZeroPlugin`` extended with a 2D (data-parallel x sequence-parallel)
    process-group mesh.

    The ``world_size`` ranks are arranged as a ``(dp_size, sp_size)`` mesh: ZeRO
    sharding and gradient reduction run within each data-parallel group, while
    ranks in the same sequence-parallel group cooperate on a single sequence.
    For example, ``world_size=8`` with ``sp_size=2`` yields ``dp_size=4``.
    """

    def __init__(
        self,
        sp_size: int = 1,
        stage: int = 2,
        precision: str = "fp16",
        initial_scale: float = 2**32,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
        reduce_bucket_size_in_m: int = 12,
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = True,
        cpu_offload: bool = False,
        master_weights: bool = True,
        verbose: bool = False,
    ) -> None:
        super().__init__(
            stage=stage,
            precision=precision,
            initial_scale=initial_scale,
            min_scale=min_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            max_scale=max_scale,
            max_norm=max_norm,
            norm_type=norm_type,
            reduce_bucket_size_in_m=reduce_bucket_size_in_m,
            communication_dtype=communication_dtype,
            overlap_communication=overlap_communication,
            cpu_offload=cpu_offload,
            master_weights=master_weights,
            verbose=verbose,
        )
        self.sp_size = sp_size
        assert self.world_size % sp_size == 0, "world_size must be divisible by sp_size"
        self.dp_size = self.world_size // sp_size
        # Build the (dp_size, sp_size) mesh and cache this rank's groups and coordinates.
        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.sp_size)
        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
        self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
        self.dp_rank = self.pg_mesh.coordinate(DP_AXIS)
        self.sp_rank = self.pg_mesh.coordinate(SP_AXIS)

    def __del__(self):
        """Destroy the process groups in the ProcessGroupMesh."""
        self.pg_mesh.destroy_mesh_process_groups()

    def prepare_dataloader(
        self,
        dataset,
        batch_size,
        shuffle=False,
        seed=1024,
        drop_last=False,
        pin_memory=False,
        num_workers=0,
        distributed_sampler_cls=None,
        **kwargs,
    ):
        _kwargs = kwargs.copy()
        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
        # Shard the dataset across data-parallel ranks only: every rank in the
        # same sequence-parallel group must receive the identical batch.
        sampler = distributed_sampler_cls(dataset, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle)

        # Seed each worker process deterministically for reproducible data loading.
        def seed_worker(worker_id):
            worker_seed = seed
            np.random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            random.seed(worker_seed)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            worker_init_fn=seed_worker,
            drop_last=drop_last,
            pin_memory=pin_memory,
            num_workers=num_workers,
            **_kwargs,
        )
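

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the plugin). It shows one way
# this plugin could be handed to a Booster. `MyModel` and `my_dataset` are
# hypothetical placeholders, the script is assumed to be launched with
# torchrun / `colossalai run`, and the `launch_from_torch` signature varies
# across ColossalAI versions.
#
#   import colossalai
#   from colossalai.booster import Booster
#   from colossalai.nn.optimizer import HybridAdam
#
#   colossalai.launch_from_torch(config={})  # older versions take `config`
#   plugin = ZeroSeqParallelPlugin(sp_size=2, stage=2, precision="fp16")
#   booster = Booster(plugin=plugin)
#
#   model = MyModel()                                  # hypothetical model
#   optimizer = HybridAdam(model.parameters(), lr=1e-4)
#   dataloader = plugin.prepare_dataloader(my_dataset, batch_size=8, shuffle=True)
#   model, optimizer, *_ = booster.boost(model, optimizer, dataloader=dataloader)
# ---------------------------------------------------------------------------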