Spaces:
Runtime error
Runtime error
import collections | |
from typing import Callable | |
import torch | |
from torch import distributed | |
from torch.nn.functional import linear | |
from torch.nn.functional import normalize | |
class PartialFC(torch.nn.Module):
    """
    https://arxiv.org/abs/2203.15565
    A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).

    When sample rate less than 1, in each iteration, positive class centers and a random subset of
    negative class centers are selected to compute the margin-based softmax loss, all class
    centers are still maintained throughout the whole training process, but only a subset is
    selected and updated in each iteration.

    .. note::
        When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1).

    .. note::
        When sample rate is less than 1, only ``torch.optim.SGD`` is supported, and the
        parameter of this module must be the first parameter of the optimizer's *last*
        param group (see :meth:`sample`).

    Example:
    --------
    >>> module_pfc = PartialFC(margin_loss, embedding_size=512, num_classes=8000000, sample_rate=0.2)
    >>> for img, labels in data_loader:
    >>>     embeddings = net(img)
    >>>     loss = module_pfc(embeddings, labels, optimizer)
    >>>     loss.backward()
    >>>     optimizer.step()
    """

    _version = 1

    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
        fp16: bool = False,
    ):
        """
        Parameters:
        -----------
        margin_loss: Callable
            Called as ``margin_loss(logits, labels)`` on the clamped cosine logits, required
        embedding_size: int
            The dimension of embedding, required
        num_classes: int
            Total number of classes, required
        sample_rate: float
            The rate of negative centers participating in the calculation, default is 1.0.
        fp16: bool
            Whether to run normalisation and the linear layer under CUDA autocast, default is False.

        Raises:
        -------
        AssertionError
            If ``torch.distributed`` has not been initialised.
        TypeError
            If ``margin_loss`` is not callable.
        """
        super(PartialFC, self).__init__()
        assert distributed.is_initialized(), "must initialize distributed before create this"
        # Fail early with a meaningful error (a bare ``raise`` here would only
        # produce "No active exception to reraise").
        if not callable(margin_loss):
            raise TypeError("margin_loss must be callable, got {}".format(type(margin_loss).__name__))

        self.rank = distributed.get_rank()
        self.world_size = distributed.get_world_size()

        self.dist_cross_entropy = DistCrossEntropy()
        self.embedding_size = embedding_size
        self.sample_rate: float = sample_rate
        self.fp16 = fp16

        # Classes are split as evenly as possible across ranks; the first
        # ``num_classes % world_size`` ranks hold one extra class each.
        self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)
        self.class_start: int = num_classes // self.world_size * self.rank + min(
            self.rank, num_classes % self.world_size
        )
        self.num_sample: int = int(self.sample_rate * self.num_local)
        self.last_batch_size: int = 0

        self.weight: torch.Tensor
        self.weight_mom: torch.Tensor
        self.weight_activated: torch.nn.Parameter
        self.weight_activated_mom: torch.Tensor
        self.is_updated: bool = True
        self.init_weight_update: bool = True

        if self.sample_rate < 1:
            # Full set of local class centers (and their momentum) live in
            # buffers; only the sampled slice is exposed as a Parameter.
            self.register_buffer("weight", tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
            self.register_buffer("weight_mom", tensor=torch.zeros_like(self.weight))
            self.register_parameter("weight_activated", param=torch.nn.Parameter(torch.empty(0, 0)))
            self.register_buffer("weight_activated_mom", tensor=torch.empty(0, 0))
            self.register_buffer("weight_index", tensor=torch.empty(0, 0))
        else:
            self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))

        # margin_loss
        self.margin_softmax = margin_loss

    @torch.no_grad()
    def sample(self, labels: torch.Tensor, index_positive: torch.Tensor, optimizer: torch.optim.Optimizer):
        """
        Select the class centers used in this iteration and hand them to the optimizer.

        This function will change the value of labels: entries selected by
        ``index_positive`` are remapped in place to their position inside the
        sampled index.

        Parameters:
        -----------
        labels: torch.Tensor
            Labels relative to this rank's class range (``forward`` passes a (N, 1) long tensor).
        index_positive: torch.Tensor
            Boolean mask over ``labels`` marking the entries owned by this rank.
        optimizer: torch.optim.Optimizer
            Must be ``torch.optim.SGD``; its state for the partial-FC parameter is replaced.

        Raises:
        -------
        NotImplementedError
            If ``optimizer`` is not an SGD optimizer. Raised before anything is mutated.
        """
        # Validate first so that an unsupported optimizer does not leave
        # ``labels`` and the module half-updated.
        if not isinstance(optimizer, torch.optim.SGD):
            raise NotImplementedError(
                "PartialFC with sample_rate < 1 only supports torch.optim.SGD, got {}".format(
                    type(optimizer).__name__
                )
            )

        device = self.weight.device
        positive = torch.unique(labels[index_positive], sorted=True).to(device)
        if self.num_sample - positive.size(0) >= 0:
            # Random scores in [0, 1); positives get 2.0 so top-k always keeps them.
            # Generated with the default (CPU) generator, then moved to the weight's device.
            perm = torch.rand(size=[self.num_local]).to(device)
            perm[positive] = 2.0
            index = torch.topk(perm, k=self.num_sample)[1]
            index = index.sort()[0]
        else:
            # More positives than the sampling budget: use exactly the positives.
            index = positive
        self.weight_index = index

        # ``index`` is sorted, so searchsorted maps each local label to its row
        # in the sampled weight matrix.
        labels[index_positive] = torch.searchsorted(index, labels[index_positive])

        self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
        self.weight_activated_mom = self.weight_mom[self.weight_index]

        # TODO the params of partial fc must be last in the params list
        optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
        optimizer.param_groups[-1]["params"][0] = self.weight_activated
        optimizer.state[self.weight_activated]["momentum_buffer"] = self.weight_activated_mom

    @torch.no_grad()
    def update(self):
        """
        partial weight to global

        Writes the sampled weights and momentum of the previous iteration back
        into the full buffers. The very first call is a no-op because nothing
        has been sampled yet. Runs without autograd so the buffers never become
        part of a graph.
        """
        if self.init_weight_update:
            self.init_weight_update = False
            return

        if self.sample_rate < 1:
            self.weight[self.weight_index] = self.weight_activated
            self.weight_mom[self.weight_index] = self.weight_activated_mom

    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ):
        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank). Squeezed in place.
        optimizer: torch.optim.Optimizer
            Only used when sample rate is less than 1 (passed to ``sample``).

        Returns:
        -------
        loss: torch.Tensor
            The value returned by the distributed cross entropy.
        """
        local_labels.squeeze_()
        # A batch of one is squeezed to a 0-dim tensor; restore the 1-D shape
        # expected by the all_gather buffers below.
        local_labels = local_labels.reshape(-1).long()
        self.update()

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, "last batch size do not equal current batch size: {} vs {}".format(
            self.last_batch_size, batch_size
        )

        # Allocate the gather buffers directly on the inputs' devices.
        _gather_embeddings = [
            torch.zeros((batch_size, self.embedding_size), device=local_embeddings.device)
            for _ in range(self.world_size)
        ]
        _gather_labels = [
            torch.zeros(batch_size, dtype=torch.long, device=local_labels.device) for _ in range(self.world_size)
        ]
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)

        labels = labels.view(-1, 1)
        # Labels outside this rank's class range become -1; the rest are made
        # relative to ``class_start``.
        index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)
        labels[~index_positive] = -1
        labels[index_positive] -= self.class_start

        if self.sample_rate < 1:
            self.sample(labels, index_positive, optimizer)

        with torch.cuda.amp.autocast(self.fp16):
            norm_embeddings = normalize(embeddings)
            norm_weight_activated = normalize(self.weight_activated)
            logits = linear(norm_embeddings, norm_weight_activated)
        if self.fp16:
            logits = logits.float()
        logits = logits.clamp(-1, 1)

        logits = self.margin_softmax(logits, labels)
        loss = self.dist_cross_entropy(logits, labels)
        return loss

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """
        Return a state dict holding the full local class-center matrix under the key ``"weight"``.

        The key is not prefixed, and momentum / sampled tensors are not saved.
        """
        if destination is None:
            destination = collections.OrderedDict()
            destination._metadata = collections.OrderedDict()

        for name, module in self._modules.items():
            if module is not None:
                # Keyword arguments: positional ones are deprecated by newer PyTorch.
                module.state_dict(destination=destination, prefix=prefix + name + ".", keep_vars=keep_vars)
        if self.sample_rate < 1:
            destination["weight"] = self.weight.detach()
        else:
            destination["weight"] = self.weight_activated.data.detach()
        return destination

    def load_state_dict(self, state_dict, strict: bool = True):
        """
        Load the class centers from ``state_dict["weight"]``.

        With sample rate less than 1 the momentum, sampled weights and sampled
        index are zeroed in place. ``strict`` is accepted but not used.
        """
        if self.sample_rate < 1:
            self.weight = state_dict["weight"].to(self.weight.device)
            self.weight_mom.zero_()
            self.weight_activated.data.zero_()
            self.weight_activated_mom.zero_()
            self.weight_index.zero_()
        else:
            self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)
class PartialFCAdamW(torch.nn.Module):
    """
    Partial FC layer whose sampled class centers are optimised with Adam / AdamW.

    Same structure as ``PartialFC``, but instead of an SGD momentum buffer it
    keeps the first and second moment estimates (``exp_avg`` / ``exp_avg_sq``)
    for every local class center and swaps the sampled slices into the
    optimizer state on every iteration.

    When sample rate is less than 1, the parameter of this module must be the
    first parameter of the optimizer's *last* param group (see :meth:`sample`).
    """

    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
        fp16: bool = False,
    ):
        """
        Parameters:
        -----------
        margin_loss: Callable
            Called as ``margin_loss(logits, labels)`` on the clamped cosine logits, required
        embedding_size: int
            The dimension of embedding, required
        num_classes: int
            Total number of classes, required
        sample_rate: float
            The rate of negative centers participating in the calculation, default is 1.0.
        fp16: bool
            Whether to run normalisation and the linear layer under CUDA autocast, default is False.

        Raises:
        -------
        AssertionError
            If ``torch.distributed`` has not been initialised.
        TypeError
            If ``margin_loss`` is not callable.
        """
        super(PartialFCAdamW, self).__init__()
        assert distributed.is_initialized(), "must initialize distributed before create this"
        # Fail early with a meaningful error (a bare ``raise`` here would only
        # produce "No active exception to reraise").
        if not callable(margin_loss):
            raise TypeError("margin_loss must be callable, got {}".format(type(margin_loss).__name__))

        self.rank = distributed.get_rank()
        self.world_size = distributed.get_world_size()

        self.dist_cross_entropy = DistCrossEntropy()
        self.embedding_size = embedding_size
        self.sample_rate: float = sample_rate
        self.fp16 = fp16

        # Classes are split as evenly as possible across ranks; the first
        # ``num_classes % world_size`` ranks hold one extra class each.
        self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)
        self.class_start: int = num_classes // self.world_size * self.rank + min(
            self.rank, num_classes % self.world_size
        )
        self.num_sample: int = int(self.sample_rate * self.num_local)
        self.last_batch_size: int = 0

        self.weight: torch.Tensor
        self.weight_exp_avg: torch.Tensor
        self.weight_exp_avg_sq: torch.Tensor
        self.weight_activated: torch.nn.Parameter
        self.weight_activated_exp_avg: torch.Tensor
        self.weight_activated_exp_avg_sq: torch.Tensor
        self.is_updated: bool = True
        self.init_weight_update: bool = True

        if self.sample_rate < 1:
            # Full set of local class centers (and their Adam moments) live in
            # buffers; only the sampled slice is exposed as a Parameter.
            self.register_buffer("weight", tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
            self.register_buffer("weight_exp_avg", tensor=torch.zeros_like(self.weight))
            self.register_buffer("weight_exp_avg_sq", tensor=torch.zeros_like(self.weight))
            self.register_parameter("weight_activated", param=torch.nn.Parameter(torch.empty(0, 0)))
            self.register_buffer("weight_activated_exp_avg", tensor=torch.empty(0, 0))
            self.register_buffer("weight_activated_exp_avg_sq", tensor=torch.empty(0, 0))
        else:
            self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))

        # Number of ``sample`` calls so far; written into the optimizer state as "step".
        self.step = 0

        self.margin_softmax = margin_loss

    @torch.no_grad()
    def sample(self, labels, index_positive, optimizer):
        """
        Select the class centers used in this iteration and hand them to the optimizer.

        Entries of ``labels`` selected by ``index_positive`` are remapped in
        place to their position inside the sampled index.

        Parameters:
        -----------
        labels: torch.Tensor
            Labels relative to this rank's class range (``forward`` passes a (N, 1) long tensor).
        index_positive: torch.Tensor
            Boolean mask over ``labels`` marking the entries owned by this rank.
        optimizer: torch.optim.Optimizer
            Must be ``torch.optim.Adam`` or ``torch.optim.AdamW``; its state for
            the partial-FC parameter is replaced.

        Raises:
        -------
        NotImplementedError
            If ``optimizer`` is neither Adam nor AdamW. Raised before anything is mutated.
        """
        # Validate first so that an unsupported optimizer does not leave
        # ``labels``, the step counter and the module half-updated.
        if not isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
            raise NotImplementedError(
                "PartialFCAdamW with sample_rate < 1 only supports torch.optim.Adam / AdamW, got {}".format(
                    type(optimizer).__name__
                )
            )

        self.step += 1
        device = self.weight.device
        positive = torch.unique(labels[index_positive], sorted=True).to(device)
        if self.num_sample - positive.size(0) >= 0:
            # Random scores in [0, 1); positives get 2.0 so top-k always keeps them.
            # Generated with the default (CPU) generator, then moved to the weight's device.
            perm = torch.rand(size=[self.num_local]).to(device)
            perm[positive] = 2.0
            index = torch.topk(perm, k=self.num_sample)[1]
            index = index.sort()[0]
        else:
            # More positives than the sampling budget: use exactly the positives.
            index = positive
        self.weight_index = index

        # ``index`` is sorted, so searchsorted maps each local label to its row
        # in the sampled weight matrix.
        labels[index_positive] = torch.searchsorted(index, labels[index_positive])

        self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
        self.weight_activated_exp_avg = self.weight_exp_avg[self.weight_index]
        self.weight_activated_exp_avg_sq = self.weight_exp_avg_sq[self.weight_index]

        # TODO the params of partial fc must be last in the params list
        optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
        optimizer.param_groups[-1]["params"][0] = self.weight_activated
        optimizer.state[self.weight_activated]["exp_avg"] = self.weight_activated_exp_avg
        optimizer.state[self.weight_activated]["exp_avg_sq"] = self.weight_activated_exp_avg_sq
        # NOTE(review): "step" is stored as a Python int; some PyTorch versions
        # expect a tensor here - confirm against the installed version.
        optimizer.state[self.weight_activated]["step"] = self.step

    @torch.no_grad()
    def update(self):
        """
        partial weight to global

        Writes the sampled weights and Adam moments of the previous iteration
        back into the full buffers. The very first call is a no-op because
        nothing has been sampled yet. Runs without autograd so the buffers
        never become part of a graph.
        """
        if self.init_weight_update:
            self.init_weight_update = False
            return

        if self.sample_rate < 1:
            self.weight[self.weight_index] = self.weight_activated
            self.weight_exp_avg[self.weight_index] = self.weight_activated_exp_avg
            self.weight_exp_avg_sq[self.weight_index] = self.weight_activated_exp_avg_sq

    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ):
        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank). Squeezed in place.
        optimizer: torch.optim.Optimizer
            Only used when sample rate is less than 1 (passed to ``sample``).

        Returns:
        -------
        loss: torch.Tensor
            The value returned by the distributed cross entropy.
        """
        local_labels.squeeze_()
        # A batch of one is squeezed to a 0-dim tensor; restore the 1-D shape
        # expected by the all_gather buffers below.
        local_labels = local_labels.reshape(-1).long()
        self.update()

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, "last batch size do not equal current batch size: {} vs {}".format(
            self.last_batch_size, batch_size
        )

        # Allocate the gather buffers directly on the inputs' devices.
        _gather_embeddings = [
            torch.zeros((batch_size, self.embedding_size), device=local_embeddings.device)
            for _ in range(self.world_size)
        ]
        _gather_labels = [
            torch.zeros(batch_size, dtype=torch.long, device=local_labels.device) for _ in range(self.world_size)
        ]
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)

        labels = labels.view(-1, 1)
        # Labels outside this rank's class range become -1; the rest are made
        # relative to ``class_start``.
        index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)
        labels[~index_positive] = -1
        labels[index_positive] -= self.class_start

        if self.sample_rate < 1:
            self.sample(labels, index_positive, optimizer)

        with torch.cuda.amp.autocast(self.fp16):
            norm_embeddings = normalize(embeddings)
            norm_weight_activated = normalize(self.weight_activated)
            logits = linear(norm_embeddings, norm_weight_activated)
        if self.fp16:
            logits = logits.float()
        logits = logits.clamp(-1, 1)

        logits = self.margin_softmax(logits, labels)
        loss = self.dist_cross_entropy(logits, labels)
        return loss

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """
        Return a state dict holding the full local class-center matrix under the key ``"weight"``.

        The key is not prefixed, and Adam moments / sampled tensors are not saved.
        """
        if destination is None:
            destination = collections.OrderedDict()
            destination._metadata = collections.OrderedDict()

        for name, module in self._modules.items():
            if module is not None:
                # Keyword arguments: positional ones are deprecated by newer PyTorch.
                module.state_dict(destination=destination, prefix=prefix + name + ".", keep_vars=keep_vars)
        if self.sample_rate < 1:
            destination["weight"] = self.weight.detach()
        else:
            destination["weight"] = self.weight_activated.data.detach()
        return destination

    def load_state_dict(self, state_dict, strict: bool = True):
        """
        Load the class centers from ``state_dict["weight"]``.

        With sample rate less than 1 the Adam moments and the sampled tensors
        are zeroed in place. ``strict`` is accepted but not used.
        """
        if self.sample_rate < 1:
            self.weight = state_dict["weight"].to(self.weight.device)
            self.weight_exp_avg.zero_()
            self.weight_exp_avg_sq.zero_()
            self.weight_activated.data.zero_()
            self.weight_activated_exp_avg.zero_()
            self.weight_activated_exp_avg_sq.zero_()
        else:
            self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)
class DistCrossEntropyFunc(torch.autograd.Function):
    """
    CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax.
    Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):

    Both ``forward`` and ``backward`` work in place on ``logits`` to avoid
    extra copies of the logit matrix.
    """

    @staticmethod
    def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
        """
        Compute the mean negative log-likelihood over the batch.

        Args:
            logits (torch.Tensor): 2-D logits held by this rank; overwritten in
                place with the softmax probabilities (normalised across ranks).
            label (torch.Tensor): 2-D index tensor with one column; rows equal
                to -1 are skipped on this rank.
        Returns:
            Scalar tensor: ``-mean(log(p_target))``.
        """
        batch_size = logits.size(0)
        # for numerical stability
        max_logits, _ = torch.max(logits, dim=1, keepdim=True)
        # local to global
        distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
        logits.sub_(max_logits)

        logits.exp_()
        sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
        # local to global
        distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
        logits.div_(sum_logits_exp)

        # Rows whose label is not -1 on this rank.
        index = torch.where(label != -1)[0]
        # loss
        loss = torch.zeros(batch_size, 1, device=logits.device)
        loss[index] = logits[index].gather(1, label[index])
        # Rows skipped on this rank stay zero, so the SUM collects each row's
        # target probability from whichever rank filled it in.
        distributed.all_reduce(loss, distributed.ReduceOp.SUM)
        ctx.save_for_backward(index, logits, label)
        # Clamp before the log so a zero probability does not produce -inf.
        return loss.clamp_min_(1e-30).log_().mean() * (-1)

    @staticmethod
    def backward(ctx, loss_gradient):
        """
        Args:
            loss_grad (torch.Tensor): gradient backward by last layer
        Returns:
            gradients for each input in forward function
            `None` gradients for one-hot label
        """
        (
            index,
            logits,
            label,
        ) = ctx.saved_tensors
        batch_size = logits.size(0)
        one_hot = torch.zeros(size=[index.size(0), logits.size(1)], device=logits.device)
        one_hot.scatter_(1, label[index], 1)
        # Gradient of softmax cross entropy: (probabilities - one_hot) / batch_size,
        # computed in place on the saved probabilities.
        logits[index] -= one_hot
        logits.div_(batch_size)
        return logits * loss_gradient.item(), None
class DistCrossEntropy(torch.nn.Module):
    """Module wrapper that applies ``DistCrossEntropyFunc`` to its inputs."""

    def __init__(self):
        super().__init__()

    def forward(self, logit_part, label_part):
        """Return ``DistCrossEntropyFunc.apply(logit_part, label_part)``."""
        loss = DistCrossEntropyFunc.apply(logit_part, label_part)
        return loss
class AllGatherFunc(torch.autograd.Function):
    """AllGather op with gradient backward"""

    @staticmethod
    def forward(ctx, tensor, *gather_list):
        """
        Gather ``tensor`` from every rank into the pre-allocated ``gather_list``.

        Returns the filled buffers as a tuple, one entry per buffer passed in.
        """
        gather_list = list(gather_list)
        distributed.all_gather(gather_list, tensor)
        return tuple(gather_list)

    @staticmethod
    def backward(ctx, *grads):
        """
        Reduce each output gradient onto the rank with the same index.

        ``grads[i]`` is sum-reduced to rank ``i``; this rank keeps the reduced
        ``grads[rank]`` as the gradient of its own input. The gather buffers
        receive no gradient.
        """
        grad_list = list(grads)
        rank = distributed.get_rank()
        grad_out = grad_list[rank]

        # Launch all reductions asynchronously, then wait for every one of them.
        dist_ops = [
            distributed.reduce(grad_list[i], i, distributed.ReduceOp.SUM, async_op=True)
            for i in range(distributed.get_world_size())
        ]
        for _op in dist_ops:
            _op.wait()

        grad_out *= len(grad_list)  # cooperate with distributed loss function
        return (grad_out, *[None for _ in range(len(grad_list))])


# Functional entry point: ``AllGather(tensor, *buffers)``.
AllGather = AllGatherFunc.apply