import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from transformers import DistilBertTokenizer

import config as CFG
from dataset import CLIPDataset, get_transforms
from CLIP import CLIPModel
from utils import AvgMeter, get_lr


def make_train_valid_dfs():
    dataframe = pd.read_csv(f"{CFG.captions_path}/captions.csv")
    # Assign a unique id per row so the 80/20 train/valid split below can be done on ids.
    dataframe["id"] = dataframe.index
    max_id = dataframe["id"].max() + 1 if not CFG.debug else 100
    image_ids = np.arange(0, max_id)
    np.random.seed(42)
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe


def build_loaders(dataframe, tokenizer, mode):
    transforms = get_transforms(mode=mode)
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=True if mode == "train" else False,
    )
    return dataloader


def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    for batch in tqdm_object:
        # Move tensors to the target device; the raw caption strings are not used by the model.
        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Only step here for schedulers that update per batch; ReduceLROnPlateau
        # is stepped once per epoch in main() with the validation loss.
        if step == "batch":
            lr_scheduler.step()

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter


def valid_epoch(model, valid_loader):
    loss_meter = AvgMeter()
    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    for batch in tqdm_object:
        batch = {k: v.to(CFG.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter


def main():
    train_df, valid_df = make_train_valid_dfs()
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    train_loader = build_loaders(train_df, tokenizer, mode="train")
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    step = "epoch"

    best_loss = float("inf")
    for epoch in range(CFG.epochs):
        print(f"Epoch: {epoch + 1}")
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)

        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            torch.save(model.state_dict(), "best2.pt")
            print("Saved Best Model!")

        # Step the plateau scheduler on the validation loss so the learning rate
        # actually decays when validation stops improving.
        lr_scheduler.step(valid_loss.avg)


if __name__ == "__main__":
    main()