import itertools
import os

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn.functional as F
from torch import nn
from tqdm.autonotebook import tqdm
from transformers import DistilBertConfig, DistilBertModel, DistilBertTokenizer

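# A compact CLIP-style training script: a timm image encoder (ResNet50) and a
# DistilBERT text encoder, each followed by a projection head, trained with a
# symmetric contrastive loss over image-caption pairs.
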
class CFG:
    debug = False
    image_path = ""
    captions_path = os.getcwd()
    batch_size = 30
    num_workers = 4
    head_lr = 1e-3
    image_encoder_lr = 1e-4
    text_encoder_lr = 1e-5
    weight_decay = 1e-3
    patience = 1
    factor = 0.8
    epochs = 4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_name = 'resnet50'
    image_embedding = 2048
    text_encoder_model = "distilbert/distilbert-base-uncased"
    text_embedding = 768
    text_tokenizer = "distilbert/distilbert-base-uncased"
    max_length = 200

    pretrained = True  # for both image and text encoders
    trainable = True  # for both image and text encoders
    temperature = 1.0

    # image size
    size = 224

    # for the projection head; used for both image and text encoders
    num_projection_layers = 1
    projection_dim = 256
    dropout = 0.1

class AvgMeter:
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]

class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; so, if there are
        multiple captions for each image, the image_filenames must have repeated
        file names.
        """

        self.image_filenames = image_filenames
        self.captions = list(captions)
        # tokenize every caption once up front; padding/truncation gives
        # fixed-length input_ids and attention_mask lists
        self.encoded_captions = tokenizer(
            list(captions), padding=True, truncation=True, max_length=CFG.max_length
        )
        self.transforms = transforms

    def __getitem__(self, idx):
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encoded_captions.items()
        }

        image = cv2.imread(f"{CFG.image_path}/{self.image_filenames[idx]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV reads BGR
        image = self.transforms(image=image)['image']
        item['image'] = torch.tensor(image).permute(2, 0, 1).float()  # HWC -> CHW
        item['caption'] = self.captions[idx]

        return item

    def __len__(self):
        return len(self.captions)

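# Illustrative usage (a sketch, not executed; the filename and caption are
# hypothetical, and CFG.image_path is assumed to point at the image folder):
#
#   tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
#   dataset = CLIPDataset(["example.jpg"], ["a dog running on grass"],
#                         tokenizer, get_transforms(mode="valid"))
#   item = dataset[0]
#   # item["image"]: float tensor of shape (3, CFG.size, CFG.size)
#   # item["input_ids"] / item["attention_mask"]: padded token tensors
#   # item["caption"]: the raw caption string
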
def get_transforms(mode="train"):
    # the two branches are currently identical; train-specific augmentations
    # (e.g. flips or crops) would go in the "train" branch
    if mode == "train":
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )
    else:
        return A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ]
        )

class ImageEncoder(nn.Module):
    """
    Encode images to a fixed size vector
    """

    def __init__(
        self, model_name=CFG.model_name, pretrained=CFG.pretrained, trainable=CFG.trainable
    ):
        super().__init__()
        # num_classes=0 removes the classifier head; average pooling returns
        # a (batch_size, image_embedding) feature vector
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)

class TextEncoder(nn.Module):
    def __init__(self, model_name=CFG.text_encoder_model, pretrained=CFG.pretrained, trainable=CFG.trainable):
        super().__init__()
        if pretrained:
            self.model = DistilBertModel.from_pretrained(model_name, use_safetensors=True)
        else:
            self.model = DistilBertModel(config=DistilBertConfig())

        for p in self.model.parameters():
            p.requires_grad = trainable

        # use the CLS token's hidden state as the sentence embedding
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = output.last_hidden_state
        return last_hidden_state[:, self.target_token_idx, :]

class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=CFG.projection_dim,
        dropout=CFG.dropout,
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

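# The head projects both encoders into a shared space, with a residual
# connection around the non-linear block. E.g. ProjectionHead(2048) maps
# (batch, 2048) ResNet features and ProjectionHead(768) maps (batch, 768)
# DistilBERT features to (batch, CFG.projection_dim) embeddings.
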
class CLIPModel(nn.Module):
    def __init__(
        self,
        temperature=CFG.temperature,
        image_embedding=CFG.image_embedding,
        text_embedding=CFG.text_embedding,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature

    def forward(self, batch):
        # getting image and text features
        image_features = self.image_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )

        # getting image and text embeddings (with the same dimension)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # calculating the loss; the soft targets average the image-image and
        # text-text similarities and share the logits' temperature scaling
        logits = (text_embeddings @ image_embeddings.T) / self.temperature
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax(
            (images_similarity + texts_similarity) / (2 * self.temperature), dim=-1
        )
        texts_loss = cross_entropy(logits, targets, reduction='none')
        images_loss = cross_entropy(logits.T, targets.T, reduction='none')
        loss = (images_loss + texts_loss) / 2.0  # symmetric over both directions
        return loss.mean()

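# Quick smoke test for the loss (an illustrative sketch, not executed here;
# batch size and token ids are arbitrary):
#
#   model = CLIPModel()
#   batch = {
#       "image": torch.randn(4, 3, CFG.size, CFG.size),
#       "input_ids": torch.randint(0, 1000, (4, 12)),
#       "attention_mask": torch.ones(4, 12, dtype=torch.long),
#   }
#   loss = model(batch)  # scalar tensor
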
def cross_entropy(preds, targets, reduction='none'):
    # cross-entropy that accepts soft (non one-hot) target distributions
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()

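# With one-hot targets this reduces to the standard cross-entropy, e.g.
# (illustrative, not executed):
#
#   preds = torch.randn(4, 4)
#   labels = torch.arange(4)
#   soft = F.one_hot(labels, 4).float()
#   assert torch.allclose(cross_entropy(preds, soft, reduction="mean"),
#                         F.cross_entropy(preds, labels))
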
def make_train_valid_dfs():
    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
    dataframe['id'] = dataframe.index
    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
    image_ids = np.arange(0, max_id)
    np.random.seed(42)
    # hold out a random 20% of the rows for validation
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    train_ids = np.setdiff1d(image_ids, valid_ids)
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe

def build_loaders(dataframe, tokenizer, mode):
    transforms = get_transforms(mode=mode)
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=True if mode == "train" else False,
    )
    return dataloader

def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    for batch in tqdm_object:
        # move tensors to the device; the raw caption strings are not needed
        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter

def valid_epoch(model, valid_loader):
    loss_meter = AvgMeter()

    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    for batch in tqdm_object:
        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter

def main():
    train_df, valid_df = make_train_valid_dfs()
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    train_loader = build_loaders(train_df, tokenizer, mode="train")
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    # per-group learning rates: the pretrained encoders get smaller lrs than
    # the freshly initialized projection heads; weight decay only on the heads
    params = [
        {"params": model.image_encoder.parameters(), "lr": CFG.image_encoder_lr},
        {"params": model.text_encoder.parameters(), "lr": CFG.text_encoder_lr},
        {"params": itertools.chain(
            model.image_projection.parameters(), model.text_projection.parameters()
        ), "lr": CFG.head_lr, "weight_decay": CFG.weight_decay},
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=0.)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    # ReduceLROnPlateau is stepped once per epoch with the validation loss
    step = "epoch"

    best_loss = float('inf')
    for epoch in range(CFG.epochs):
        print(f"Epoch: {epoch + 1}")
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)

        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            torch.save(model.state_dict(), "best.pt")
            print("Saved Best Model!")

        lr_scheduler.step(valid_loss.avg)

if __name__ == "__main__":
    main()
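# After training, the saved checkpoint can be reloaded for encoding, e.g.
# (a minimal sketch; "best.pt" is the file written above):
#
#   model = CLIPModel().to(CFG.device)
#   model.load_state_dict(torch.load("best.pt", map_location=CFG.device))
#   model.eval()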