|
import csv |
|
import os |
|
|
|
import numpy as np |
|
import torch |
|
import torchvision |
|
import torchvision.transforms as transforms |
|
from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader |
|
|
|
from . import video_transforms |
|
from .utils import center_crop_arr |
|
|
|
|
|
def get_transforms_video(resolution=256):
    """Build the preprocessing pipeline applied to a clip of video frames.

    The steps run in this order: ``ToTensorVideo``, ``RandomHorizontalFlipVideo``,
    ``UCFCenterCropVideo(resolution)``, then an in-place ``Normalize`` with
    mean 0.5 and std 0.5 on each of three channels.

    Args:
        resolution: Size handed to ``UCFCenterCropVideo``.

    Returns:
        A ``transforms.Compose`` wrapping the four steps above.
    """
    steps = [
        video_transforms.ToTensorVideo(),
        video_transforms.RandomHorizontalFlipVideo(),
        video_transforms.UCFCenterCropVideo(resolution),
        # inplace=True: the normalisation overwrites its input tensor.
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3, inplace=True),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
def get_transforms_image(image_size=256):
    """Build the preprocessing pipeline applied to a single PIL image.

    The steps run in this order: ``center_crop_arr`` to ``image_size``,
    ``RandomHorizontalFlip``, ``ToTensor``, then an in-place ``Normalize``
    with mean 0.5 and std 0.5 on each of three channels.

    Args:
        image_size: Target size forwarded to ``center_crop_arr``.

    Returns:
        A ``transforms.Compose`` wrapping the four steps above.
    """

    def _center_crop(pil_image):
        # Closes over image_size so the step takes the image as its only argument.
        return center_crop_arr(pil_image, image_size)

    steps = [
        transforms.Lambda(_center_crop),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # inplace=True: the normalisation overwrites its input tensor.
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3, inplace=True),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
class DatasetFromCSV(torch.utils.data.Dataset):
    """Dataset of (video, text) pairs listed in a CSV file.

    Every CSV row is read as ``[path, text, ...]``: column 0 is the media path
    (joined onto ``root`` when one is given) and column 1 is the text returned
    with the sample. The extension of the *first* row decides how the whole
    file is treated: ``mp4``/``avi``/``mov``/``mkv`` means video, any entry of
    torchvision's ``IMG_EXTENSIONS`` means still image. A still image is
    repeated ``num_frames`` times to form a clip.

    Args:
        csv_path: Path of the CSV file to read.
        num_frames: Number of frames in every returned clip.
        frame_interval: Multiplied by ``num_frames`` to size the window given
            to ``video_transforms.TemporalRandomCrop``.
        transform: Callable applied to the selected frames (video) or to the
            loaded PIL image. With the default ``None`` every sample fails to
            load, so a transform must be supplied before indexing.
        root: Optional directory prepended to every path from the CSV.

    Raises:
        ValueError: If the CSV contains no rows, or the first row's extension
            is neither a known video extension nor in ``IMG_EXTENSIONS``.
    """

    # How many samples __getitem__ tries before giving up.
    _MAX_RETRIES = 10

    def __init__(
        self,
        csv_path,
        num_frames=16,
        frame_interval=1,
        transform=None,
        root=None,
    ):
        self.csv_path = csv_path
        # newline="" is required by the csv module so that quoted fields
        # containing line breaks (e.g. multi-line captions) parse correctly.
        with open(csv_path, "r", newline="") as f:
            # csv.reader yields [] for blank lines; such rows can never be
            # loaded, so drop them instead of keeping permanently bad samples.
            self.samples = [row for row in csv.reader(f) if row]

        if not self.samples:
            raise ValueError(f"No samples found in {csv_path}")

        # The first row alone decides video vs. image for the whole file.
        ext = self.samples[0][0].split(".")[-1]
        if ext.lower() in ("mp4", "avi", "mov", "mkv"):
            self.is_video = True
        elif f".{ext.lower()}" in IMG_EXTENSIONS:
            self.is_video = False
        else:
            # Explicit raise rather than assert: asserts vanish under -O and
            # the file would then be silently treated as images.
            raise ValueError(f"Unsupported file format: {ext}")

        self.transform = transform

        self.num_frames = num_frames
        self.frame_interval = frame_interval
        self.temporal_sample = video_transforms.TemporalRandomCrop(num_frames * frame_interval)
        self.root = root

    def getitem(self, index):
        """Load the sample at ``index`` without any retry logic.

        Returns:
            ``{"video": tensor, "text": str}``. The tensor has its first two
            axes swapped relative to the frame-first layout built here, i.e.
            (C, T, H, W) provided ``transform`` keeps the (T, C, H, W) layout
            for videos and returns a 3-D tensor for images.

        Raises:
            ValueError: If the sampled window of a video holds fewer than
                ``num_frames`` frames.
            Exception: Anything raised while indexing the row, decoding the
                file, or applying ``transform`` propagates unchanged.
        """
        sample = self.samples[index]
        path = sample[0]
        if self.root:
            path = os.path.join(self.root, path)
        text = sample[1]

        if self.is_video:
            # Audio frames and stream metadata are not used.
            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
            total_frames = len(vframes)

            start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
            # Explicit raise rather than assert so the check survives -O;
            # otherwise linspace below would silently repeat frames.
            if end_frame_ind - start_frame_ind < self.num_frames:
                raise ValueError(f"{path} with index {index} has not enough frames.")
            # Evenly spaced indices across the sampled window.
            frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)

            video = vframes[frame_indice]
            video = self.transform(video)
        else:
            image = pil_loader(path)
            image = self.transform(image)
            # Turn the single image into a static clip of num_frames frames.
            video = image.unsqueeze(0).repeat(self.num_frames, 1, 1, 1)

        # Swap the frame and channel axes: (T, C, H, W) -> (C, T, H, W).
        video = video.permute(1, 0, 2, 3)

        return {"video": video, "text": text}

    def __getitem__(self, index):
        """Return a sample, substituting a random one when loading fails.

        A failed load is printed and retried with a randomly drawn index, up
        to ``_MAX_RETRIES`` attempts in total. The returned sample may
        therefore not be the one at ``index``.

        Raises:
            RuntimeError: If every attempt fails; the last underlying error is
                attached as ``__cause__``.
        """
        last_error = None
        for _ in range(self._MAX_RETRIES):
            try:
                return self.getitem(index)
            except Exception as e:  # deliberate best-effort: skip bad samples
                print(e)
                last_error = e
                index = np.random.randint(len(self))
        raise RuntimeError("Too many bad data.") from last_error

    def __len__(self):
        """Return the number of non-blank rows read from the CSV."""
        return len(self.samples)
|
|