|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
|
|
import config as CFG |
|
from modules import ImageEncoder, TextEncoder, ProjectionHead |
|
|
|
|
|
class CLIPModel(nn.Module):
    """Dual-encoder CLIP model with a symmetric contrastive loss.

    Images and texts are encoded independently, projected into a shared
    embedding space, and trained so that matching image/text pairs have
    higher similarity than mismatched ones.

    Args:
        temperature: Softmax temperature applied to the similarity logits.
        image_embedding: Dimensionality of the raw image-encoder features.
        text_embedding: Dimensionality of the raw text-encoder features.
    """

    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        # Separate projection heads map each modality into the shared space.
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        """Compute the contrastive loss for a batch.

        Args:
            batch: Dict with keys "image", "input_ids", "attention_mask".
                   Presumably image is (B, C, H, W) and the text tensors are
                   (B, seq_len) — confirm against the dataloader.

        Returns:
            Scalar tensor: the mean of the image-to-text and text-to-image
            cross-entropy losses.
        """
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )

        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Pairwise similarity between every text and every image in the batch.
        logits = (text_embeddings @ image_embeddings.T) / self.temperature

        # Soft targets: instead of hard one-hot labels on the diagonal, use the
        # averaged intra-modal similarity structure so near-duplicate pairs are
        # not penalized as hard negatives.
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        # FIX: scale the targets by the SAME 1/temperature factor as the
        # logits (the original multiplied by temperature, which is
        # inconsistent with the division above; identical when temperature=1).
        targets = F.softmax(
            (images_similarity + texts_similarity) / (2 * self.temperature), dim=-1
        )

        # Symmetric loss: cross-entropy over rows (texts) and columns (images).
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0
        return loss.mean()
|
|
|
|
|
def cross_entropy(preds, targets, reduction='none'):
    """Cross-entropy between predicted logits and soft target distributions.

    Unlike ``F.cross_entropy``, this accepts a full probability distribution
    as the target (rows of ``targets`` are expected to sum to 1).

    Args:
        preds: (N, C) tensor of unnormalized logits.
        targets: (N, C) tensor of target probabilities.
        reduction: "none" to return the per-sample loss vector of shape (N,),
            "mean" to return its scalar mean.

    Returns:
        (N,) tensor if reduction="none", scalar tensor if reduction="mean".

    Raises:
        ValueError: If ``reduction`` is not "none" or "mean" (the original
            silently returned None in that case).
    """
    # Functional form avoids constructing an nn.LogSoftmax module per call.
    loss = (-targets * F.log_softmax(preds, dim=-1)).sum(1)
    if reduction == "none":
        return loss
    if reduction == "mean":
        return loss.mean()
    raise ValueError(f"Unsupported reduction: {reduction!r}")
|
|
|
if __name__ == '__main__':
    # Smoke test: run a random batch through the model and report the loss.
    images = torch.randn(8, 3, 224, 224)
    input_ids = torch.randint(5, 300, size=(8, 25))
    attention_mask = torch.ones(8, 25)
    batch = {
        'image': images,
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }

    CLIP = CLIPModel()
    # No gradients needed for a forward-only sanity check.
    with torch.no_grad():
        loss = CLIP(batch)
    # FIX: the original printed an empty string, discarding the result.
    print(f"loss: {loss.item():.4f}")