import collections
from typing import Callable

import torch
from torch import distributed
from torch.nn.functional import linear
from torch.nn.functional import normalize


class PartialFC(torch.nn.Module):
    """
    https://arxiv.org/abs/2203.15565
    A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).

    When the sample rate is less than 1, in each iteration the positive class
    centers and a random subset of negative class centers are selected to
    compute the margin-based softmax loss. All class centers are still
    maintained throughout the whole training process, but only a subset is
    selected and updated in each iteration.

    .. note::
        When the sample rate equals 1, Partial FC is equal to model
        parallelism (the default sample rate is 1).

    This variant keeps an SGD momentum buffer next to the class centers and
    only accepts a ``torch.optim.SGD`` optimizer when sampling is enabled.

    Example:
    --------
    >>> module_pfc = PartialFC(margin_loss, embedding_size=512, num_classes=8000000, sample_rate=0.2)
    >>> for img, labels in data_loader:
    >>>     embeddings = net(img)
    >>>     loss = module_pfc(embeddings, labels, optimizer)
    >>>     loss.backward()
    >>>     optimizer.step()
    """

    _version = 1

    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
        fp16: bool = False,
    ):
        """
        Parameters:
        -----------
        margin_loss: Callable
            Called as ``margin_loss(logits, labels)`` and must return the
            logits handed to the distributed cross entropy, required.
        embedding_size: int
            The dimension of embedding, required
        num_classes: int
            Total number of classes, required
        sample_rate: float
            The rate of negative centers participating in the calculation, default is 1.0.
        fp16: bool
            Enables autocast around the logits computation, default is False.

        Raises:
        -------
        TypeError
            If ``margin_loss`` is not callable.
        AssertionError
            If ``torch.distributed`` has not been initialized.
        """
        # Fail fast with a readable message; the check used to be a bare
        # `raise`, which only produced "No active exception to reraise".
        if not callable(margin_loss):
            raise TypeError(
                "margin_loss must be callable, got {}".format(type(margin_loss).__name__)
            )

        super(PartialFC, self).__init__()
        assert distributed.is_initialized(), "must initialize distributed before create this"
        self.rank = distributed.get_rank()
        self.world_size = distributed.get_world_size()

        self.dist_cross_entropy = DistCrossEntropy()
        self.embedding_size = embedding_size
        self.sample_rate: float = sample_rate
        self.fp16 = fp16
        # Classes are split into contiguous ranges, one per rank; the first
        # `num_classes % world_size` ranks hold one extra class.
        self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)
        self.class_start: int = num_classes // self.world_size * self.rank + min(
            self.rank, num_classes % self.world_size
        )
        self.num_sample: int = int(self.sample_rate * self.num_local)
        self.last_batch_size: int = 0
        self.weight: torch.Tensor
        self.weight_mom: torch.Tensor
        self.weight_activated: torch.nn.Parameter
        self.weight_activated_mom: torch.Tensor
        self.is_updated: bool = True
        # The first call to `update()` is skipped because nothing has been
        # sampled yet at that point.
        self.init_weight_update: bool = True

        if self.sample_rate < 1:
            # Full local weight and momentum live in buffers; the sampled
            # slice is exposed as the trainable `weight_activated` parameter.
            self.register_buffer("weight", tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
            self.register_buffer("weight_mom", tensor=torch.zeros_like(self.weight))
            self.register_parameter("weight_activated", param=torch.nn.Parameter(torch.empty(0, 0)))
            self.register_buffer("weight_activated_mom", tensor=torch.empty(0, 0))
            self.register_buffer("weight_index", tensor=torch.empty(0, 0))
        else:
            self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))

        # Assigned last so sub-module registration order is unchanged when
        # margin_loss happens to be an nn.Module.
        self.margin_softmax = margin_loss

    @torch.no_grad()
    def sample(self, labels: torch.Tensor, index_positive: torch.Tensor, optimizer: torch.optim.Optimizer):
        """
        Select the class centers used in this iteration and rebind the optimizer to them.

        This function changes the values of ``labels`` in place: entries
        selected by ``index_positive`` are rewritten to positions inside the
        sampled index.

        Parameters:
        -----------
        labels: torch.Tensor
            Labels already shifted by ``class_start``; modified in place.
        index_positive: torch.Tensor
            Boolean mask selecting the labels owned by this rank.
        optimizer: torch.optim.Optimizer
            Must be a ``torch.optim.SGD`` whose last param group holds the
            Partial FC weight as its first parameter.

        Raises:
        -------
        TypeError
            If ``optimizer`` is not a ``torch.optim.SGD``.
        """
        # Checked before anything is mutated, so a wrong optimizer cannot
        # leave labels / weights half-updated.
        if not isinstance(optimizer, torch.optim.SGD):
            raise TypeError(
                "PartialFC.sample only supports torch.optim.SGD as optimizer, got {}".format(
                    type(optimizer).__name__
                )
            )

        positive = torch.unique(labels[index_positive], sorted=True).cuda()
        if self.num_sample - positive.size(0) >= 0:
            # Random scores in [0, 1); positives get 2.0 so top-k always keeps them.
            perm = torch.rand(size=[self.num_local]).cuda()
            perm[positive] = 2.0
            index = torch.topk(perm, k=self.num_sample)[1].cuda()
            index = index.sort()[0].cuda()
        else:
            # More positives than the sampling budget: use exactly the positives.
            index = positive
        self.weight_index = index

        # `index` is sorted, so searchsorted maps each label to its row in the sampled weight.
        labels[index_positive] = torch.searchsorted(index, labels[index_positive])

        self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
        self.weight_activated_mom = self.weight_mom[self.weight_index]

        # TODO the params of partial fc must be last in the params list
        optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
        optimizer.param_groups[-1]["params"][0] = self.weight_activated
        optimizer.state[self.weight_activated]["momentum_buffer"] = self.weight_activated_mom

    @torch.no_grad()
    def update(self):
        """Write the sampled weight and momentum back into the full local buffers."""
        if self.init_weight_update:
            # First call: nothing has been sampled yet.
            self.init_weight_update = False
            return

        if self.sample_rate < 1:
            self.weight[self.weight_index] = self.weight_activated
            self.weight_mom[self.weight_index] = self.weight_activated_mom

    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ):
        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank); squeezed in place.
        optimizer: torch.optim.Optimizer
            Only used when ``sample_rate < 1`` (passed to ``sample``).

        Returns:
        -------
        loss: torch.Tensor
            Result of the distributed cross entropy.
        """
        # NOTE: in-place squeeze, so the caller's tensor is reshaped as well.
        local_labels.squeeze_()
        local_labels = local_labels.long()
        # Flush the previous iteration's sampled weights before re-sampling.
        self.update()

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, "last batch size do not equal current batch size: {} vs {}".format(
            self.last_batch_size, batch_size
        )

        _gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda() for _ in range(self.world_size)]
        _gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]
        # Embeddings go through the autograd-aware gather; labels need no gradient.
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)

        labels = labels.view(-1, 1)
        # Labels outside this rank's class range become -1; the rest become local offsets.
        index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)
        labels[~index_positive] = -1
        labels[index_positive] -= self.class_start

        if self.sample_rate < 1:
            self.sample(labels, index_positive, optimizer)

        with torch.cuda.amp.autocast(self.fp16):
            norm_embeddings = normalize(embeddings)
            norm_weight_activated = normalize(self.weight_activated)
            logits = linear(norm_embeddings, norm_weight_activated)
        if self.fp16:
            logits = logits.float()
        logits = logits.clamp(-1, 1)

        logits = self.margin_softmax(logits, labels)
        loss = self.dist_cross_entropy(logits, labels)
        return loss

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """
        Collect sub-module state plus a single ``"weight"`` entry for the class centers.

        The ``"weight"`` key is written without ``prefix``; it holds the full
        local weight when sampling is enabled, else ``weight_activated``.
        """
        if destination is None:
            destination = collections.OrderedDict()
            destination._metadata = collections.OrderedDict()

        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
        if self.sample_rate < 1:
            destination["weight"] = self.weight.detach()
        else:
            destination["weight"] = self.weight_activated.data.detach()
        return destination

    def load_state_dict(self, state_dict, strict: bool = True):
        """
        Load the class centers from ``state_dict["weight"]``.

        With sampling enabled the momentum, sampled slice and sampled index
        are zeroed. ``strict`` is accepted but not consulted, and nothing is
        returned.
        """
        if self.sample_rate < 1:
            self.weight = state_dict["weight"].to(self.weight.device)
            self.weight_mom.zero_()
            self.weight_activated.data.zero_()
            self.weight_activated_mom.zero_()
            self.weight_index.zero_()
        else:
            self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)


class PartialFCAdamW(torch.nn.Module):
    """
    Partial FC layer whose sampled class centers carry Adam-style optimizer state.

    Mirrors ``PartialFC`` but keeps ``exp_avg`` / ``exp_avg_sq`` buffers
    instead of a momentum buffer, and only accepts ``torch.optim.Adam`` or
    ``torch.optim.AdamW`` when sampling is enabled.
    """

    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
        fp16: bool = False,
    ):
        """
        Parameters:
        -----------
        margin_loss: Callable
            Called as ``margin_loss(logits, labels)`` and must return the
            logits handed to the distributed cross entropy, required.
        embedding_size: int
            The dimension of embedding, required
        num_classes: int
            Total number of classes, required
        sample_rate: float
            The rate of negative centers participating in the calculation, default is 1.0.
        fp16: bool
            Enables autocast around the logits computation, default is False.

        Raises:
        -------
        TypeError
            If ``margin_loss`` is not callable.
        AssertionError
            If ``torch.distributed`` has not been initialized.
        """
        # Fail fast with a readable message; the check used to be a bare
        # `raise`, which only produced "No active exception to reraise".
        if not callable(margin_loss):
            raise TypeError(
                "margin_loss must be callable, got {}".format(type(margin_loss).__name__)
            )

        super(PartialFCAdamW, self).__init__()
        assert distributed.is_initialized(), "must initialize distributed before create this"
        self.rank = distributed.get_rank()
        self.world_size = distributed.get_world_size()

        self.dist_cross_entropy = DistCrossEntropy()
        self.embedding_size = embedding_size
        self.sample_rate: float = sample_rate
        self.fp16 = fp16
        # Classes are split into contiguous ranges, one per rank; the first
        # `num_classes % world_size` ranks hold one extra class.
        self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)
        self.class_start: int = num_classes // self.world_size * self.rank + min(
            self.rank, num_classes % self.world_size
        )
        self.num_sample: int = int(self.sample_rate * self.num_local)
        self.last_batch_size: int = 0
        self.weight: torch.Tensor
        self.weight_exp_avg: torch.Tensor
        self.weight_exp_avg_sq: torch.Tensor
        self.weight_activated: torch.nn.Parameter
        self.weight_activated_exp_avg: torch.Tensor
        self.weight_activated_exp_avg_sq: torch.Tensor

        self.is_updated: bool = True
        # The first call to `update()` is skipped because nothing has been
        # sampled yet at that point.
        self.init_weight_update: bool = True

        if self.sample_rate < 1:
            self.register_buffer("weight", tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
            self.register_buffer("weight_exp_avg", tensor=torch.zeros_like(self.weight))
            self.register_buffer("weight_exp_avg_sq", tensor=torch.zeros_like(self.weight))
            self.register_parameter("weight_activated", param=torch.nn.Parameter(torch.empty(0, 0)))
            self.register_buffer("weight_activated_exp_avg", tensor=torch.empty(0, 0))
            self.register_buffer("weight_activated_exp_avg_sq", tensor=torch.empty(0, 0))
        else:
            self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
        # Counts calls to `sample`; copied into the optimizer state as "step".
        self.step = 0

        # Assigned last so sub-module registration order is unchanged when
        # margin_loss happens to be an nn.Module.
        self.margin_softmax = margin_loss

    @torch.no_grad()
    def sample(self, labels, index_positive, optimizer):
        """
        Select the class centers used in this iteration and rebind the optimizer to them.

        ``labels`` is modified in place: entries selected by
        ``index_positive`` are rewritten to positions inside the sampled index.

        Parameters:
        -----------
        labels: torch.Tensor
            Labels already shifted by ``class_start``; modified in place.
        index_positive: torch.Tensor
            Boolean mask selecting the labels owned by this rank.
        optimizer: torch.optim.Optimizer
            Must be ``torch.optim.Adam`` or ``torch.optim.AdamW`` whose last
            param group holds the Partial FC weight as its first parameter.

        Raises:
        -------
        TypeError
            If ``optimizer`` is neither Adam nor AdamW.
        """
        # Checked before anything is mutated (including the step counter).
        if not isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
            raise TypeError(
                "PartialFCAdamW.sample only supports torch.optim.Adam / torch.optim.AdamW "
                "as optimizer, got {}".format(type(optimizer).__name__)
            )

        self.step += 1
        positive = torch.unique(labels[index_positive], sorted=True).cuda()
        if self.num_sample - positive.size(0) >= 0:
            # Random scores in [0, 1); positives get 2.0 so top-k always keeps them.
            perm = torch.rand(size=[self.num_local]).cuda()
            perm[positive] = 2.0
            index = torch.topk(perm, k=self.num_sample)[1].cuda()
            index = index.sort()[0].cuda()
        else:
            # More positives than the sampling budget: use exactly the positives.
            index = positive
        self.weight_index = index

        # `index` is sorted, so searchsorted maps each label to its row in the sampled weight.
        labels[index_positive] = torch.searchsorted(index, labels[index_positive])

        self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
        self.weight_activated_exp_avg = self.weight_exp_avg[self.weight_index]
        self.weight_activated_exp_avg_sq = self.weight_exp_avg_sq[self.weight_index]

        # TODO the params of partial fc must be last in the params list
        optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
        optimizer.param_groups[-1]["params"][0] = self.weight_activated
        optimizer.state[self.weight_activated]["exp_avg"] = self.weight_activated_exp_avg
        optimizer.state[self.weight_activated]["exp_avg_sq"] = self.weight_activated_exp_avg_sq
        # NOTE(review): stored as a plain Python int; newer PyTorch Adam/AdamW
        # implementations may expect a tensor here - confirm against the
        # PyTorch version in use.
        optimizer.state[self.weight_activated]["step"] = self.step

    @torch.no_grad()
    def update(self):
        """Write the sampled weight and Adam moments back into the full local buffers."""
        if self.init_weight_update:
            # First call: nothing has been sampled yet.
            self.init_weight_update = False
            return

        if self.sample_rate < 1:
            self.weight[self.weight_index] = self.weight_activated
            self.weight_exp_avg[self.weight_index] = self.weight_activated_exp_avg
            self.weight_exp_avg_sq[self.weight_index] = self.weight_activated_exp_avg_sq

    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ):
        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank); squeezed in place.
        optimizer: torch.optim.Optimizer
            Only used when ``sample_rate < 1`` (passed to ``sample``).

        Returns:
        -------
        loss: torch.Tensor
            Result of the distributed cross entropy.
        """
        # NOTE: in-place squeeze, so the caller's tensor is reshaped as well.
        local_labels.squeeze_()
        local_labels = local_labels.long()
        # Flush the previous iteration's sampled weights before re-sampling.
        self.update()

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, "last batch size do not equal current batch size: {} vs {}".format(
            self.last_batch_size, batch_size
        )

        _gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda() for _ in range(self.world_size)]
        _gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]
        # Embeddings go through the autograd-aware gather; labels need no gradient.
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)

        labels = labels.view(-1, 1)
        # Labels outside this rank's class range become -1; the rest become local offsets.
        index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)
        labels[~index_positive] = -1
        labels[index_positive] -= self.class_start

        if self.sample_rate < 1:
            self.sample(labels, index_positive, optimizer)

        with torch.cuda.amp.autocast(self.fp16):
            norm_embeddings = normalize(embeddings)
            norm_weight_activated = normalize(self.weight_activated)
            logits = linear(norm_embeddings, norm_weight_activated)
        if self.fp16:
            logits = logits.float()
        logits = logits.clamp(-1, 1)

        logits = self.margin_softmax(logits, labels)
        loss = self.dist_cross_entropy(logits, labels)
        return loss

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """
        Collect sub-module state plus a single ``"weight"`` entry for the class centers.

        The ``"weight"`` key is written without ``prefix``; it holds the full
        local weight when sampling is enabled, else ``weight_activated``.
        """
        if destination is None:
            destination = collections.OrderedDict()
            destination._metadata = collections.OrderedDict()

        for name, module in self._modules.items():
            if module is not None:
                module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
        if self.sample_rate < 1:
            destination["weight"] = self.weight.detach()
        else:
            destination["weight"] = self.weight_activated.data.detach()
        return destination

    def load_state_dict(self, state_dict, strict: bool = True):
        """
        Load the class centers from ``state_dict["weight"]``.

        With sampling enabled the Adam moment buffers and the sampled slice
        are zeroed. ``strict`` is accepted but not consulted, and nothing is
        returned.
        """
        if self.sample_rate < 1:
            self.weight = state_dict["weight"].to(self.weight.device)
            self.weight_exp_avg.zero_()
            self.weight_exp_avg_sq.zero_()
            self.weight_activated.data.zero_()
            self.weight_activated_exp_avg.zero_()
            self.weight_activated_exp_avg_sq.zero_()
        else:
            self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)


class DistCrossEntropyFunc(torch.autograd.Function):
    """
    Cross entropy computed in parallel over class-sharded logits.

    The row-wise maximum and the softmax denominator are all-reduced across
    ranks before the softmax is formed.
    Implementation following ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf).
    """

    @staticmethod
    def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
        """
        Compute the loss; ``logits`` is overwritten in place with the softmax probabilities.

        Rows whose ``label`` is -1 contribute nothing from this rank; their
        probability arrives through the SUM all-reduce.
        """
        batch_size = logits.size(0)
        # for numerical stability
        max_logits, _ = torch.max(logits, dim=1, keepdim=True)
        # local to global
        distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
        logits.sub_(max_logits)
        logits.exp_()
        sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
        # local to global
        distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
        logits.div_(sum_logits_exp)
        index = torch.where(label != -1)[0]
        # loss
        loss = torch.zeros(batch_size, 1, device=logits.device)
        loss[index] = logits[index].gather(1, label[index])
        distributed.all_reduce(loss, distributed.ReduceOp.SUM)
        ctx.save_for_backward(index, logits, label)
        # Clamp before the log so a zero probability cannot produce -inf.
        return loss.clamp_min_(1e-30).log_().mean() * (-1)

    @staticmethod
    def backward(ctx, loss_gradient):
        """
        Args:
            loss_gradient (torch.Tensor): gradient backward by last layer
        Returns:
            gradients for each input in forward function
            `None` gradients for one-hot label
        """
        (
            index,
            logits,
            label,
        ) = ctx.saved_tensors
        batch_size = logits.size(0)
        # Gradient is (softmax - one_hot) / batch_size, built in place on the
        # saved probabilities.
        one_hot = torch.zeros(size=[index.size(0), logits.size(1)], device=logits.device)
        one_hot.scatter_(1, label[index], 1)
        logits[index] -= one_hot
        logits.div_(batch_size)
        return logits * loss_gradient.item(), None


class DistCrossEntropy(torch.nn.Module):
    """Module wrapper that applies ``DistCrossEntropyFunc``."""

    def __init__(self):
        super(DistCrossEntropy, self).__init__()

    def forward(self, logit_part, label_part):
        """Return ``DistCrossEntropyFunc.apply(logit_part, label_part)``."""
        return DistCrossEntropyFunc.apply(logit_part, label_part)


class AllGatherFunc(torch.autograd.Function):
    """AllGather op with gradient backward"""

    @staticmethod
    def forward(ctx, tensor, *gather_list):
        """All-gather ``tensor`` into the ``gather_list`` tensors and return them as a tuple."""
        gather_list = list(gather_list)
        distributed.all_gather(gather_list, tensor)
        return tuple(gather_list)

    @staticmethod
    def backward(ctx, *grads):
        """
        Sum-reduce the i-th incoming gradient onto rank i and return this rank's share.

        Only the first forward input (``tensor``) receives a gradient; the
        gather buffers get ``None``.
        """
        grad_list = list(grads)
        rank = distributed.get_rank()
        grad_out = grad_list[rank]

        # grad_out is grad_list[rank], so one expression covers every
        # destination rank, including this one.
        dist_ops = [
            distributed.reduce(grad_list[i], i, distributed.ReduceOp.SUM, async_op=True)
            for i in range(distributed.get_world_size())
        ]
        for _op in dist_ops:
            _op.wait()

        grad_out *= len(grad_list)  # cooperate with distributed loss function
        return (grad_out, *[None for _ in range(len(grad_list))])


AllGather = AllGatherFunc.apply