import cv2
import numpy as np
from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration
import torch
from PIL import Image
import faiss
import logging
import gradio as gr
import tempfile
import os
import shutil
from tqdm.auto import tqdm
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import gc
import warnings

warnings.filterwarnings("ignore")

os.environ["TRANSFORMERS_CACHE"] = "./model_cache"
os.environ["HF_HOME"] = "./model_cache"
os.makedirs("./model_cache", exist_ok=True)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

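# Pipeline overview: sample frames from the uploaded video at a fixed interval,
# embed each frame with CLIP and caption it with BLIP-2, index the embeddings in
# FAISS, and answer questions by retrieving the frames closest to the query text.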
class VideoProcessor:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")

        self._load_models()

        # Frame sampling and batching parameters
        self.frame_interval = 30        # keep every 30th frame
        self.max_frames = 50            # cap on sampled frames per video
        self.target_size = (224, 224)   # input resolution expected by CLIP/BLIP-2
        self.batch_size = 4 if torch.cuda.is_available() else 2

    def _load_models(self):
        """Load models with optimizations and proper configurations"""
        try:
            logging.info("Loading CLIP model...")
            self.clip_model = CLIPModel.from_pretrained(
                "openai/clip-vit-base-patch32",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                cache_dir="./model_cache"
            ).to(self.device)
            self.clip_processor = CLIPProcessor.from_pretrained(
                "openai/clip-vit-base-patch32",
                cache_dir="./model_cache"
            )

            logging.info("Loading BLIP2 model...")
            model_name = "Salesforce/blip2-opt-2.7b"

            self.blip_processor = Blip2Processor.from_pretrained(
                model_name,
                cache_dir="./model_cache"
            )

            self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None,
                cache_dir="./model_cache",
                low_cpu_mem_usage=True
            )
            # With device_map="auto", accelerate already places the weights, so only
            # move the model manually on CPU-only setups.
            if not torch.cuda.is_available():
                self.blip_model = self.blip_model.to(self.device)

            self.clip_model.eval()
            self.blip_model.eval()

            logging.info("Models loaded successfully!")
        except Exception as e:
            logging.error(f"Error loading models: {str(e)}")
            raise

    def _preprocess_frame(self, frame: np.ndarray) -> Image.Image:
        """Preprocess a single frame"""
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb_frame).resize(self.target_size, Image.LANCZOS)

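    # Frames are processed in small batches; on GPU the floating-point inputs are
    # cast to fp16 to match the half-precision weights loaded in _load_models().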
    @torch.no_grad()
    def process_frame_batch(self, frames: List[np.ndarray]) -> Tuple[Optional[np.ndarray], Optional[List[str]]]:
        """Process a batch of frames efficiently"""
        try:
            pil_frames = [self._preprocess_frame(f) for f in frames]

            # CLIP image embeddings (used for retrieval)
            clip_inputs = self.clip_processor(
                images=pil_frames,
                return_tensors="pt",
                padding=True
            ).to(self.device)

            if self.device.type == "cuda":
                clip_inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in clip_inputs.items()}
            features = self.clip_model.get_image_features(**clip_inputs)

            # BLIP-2 captions (used as textual frame descriptions)
            blip_inputs = self.blip_processor(
                images=pil_frames,
                return_tensors="pt",
                padding=True
            ).to(self.device)

            if self.device.type == "cuda":
                blip_inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in blip_inputs.items()}

            # Beam search is deterministic, so no temperature is passed here.
            caption_ids = self.blip_model.generate(
                **blip_inputs,
                max_length=30,
                min_length=10,
                num_beams=5,
                length_penalty=1.0,
                do_sample=False
            )
            captions = [self.blip_processor.decode(c, skip_special_tokens=True) for c in caption_ids]

            if self.device.type == "cuda":
                torch.cuda.empty_cache()

            # FAISS expects float32, so cast the (possibly fp16) features before returning.
            return features.float().cpu().numpy(), captions

        except Exception as e:
            logging.error(f"Error in batch processing: {str(e)}")
            return None, None

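    # Frame sampling: keep every `frame_interval`-th frame, up to `max_frames`
    # frames, and push the kept frames through process_frame_batch() in batches
    # of `batch_size`.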
    def process_video(self, video_path: str, progress: gr.Progress) -> Tuple[Optional[faiss.Index], Optional[List[Dict]], str]:
        """Process video with batching and progress updates"""
        cap = None
        try:
            cap = cv2.VideoCapture(video_path)
            if not cap.isOpened():
                raise ValueError(f"Could not open video file: {video_path}")

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps or fps <= 0:
                fps = 30.0  # fall back to a nominal frame rate if metadata is missing

            frames_to_process = max(1, min(self.max_frames, total_frames // self.frame_interval))
            progress(0, desc="Initializing video processing...")

            features_list = []
            frame_data = []
            current_batch = []
            batch_positions = []

            frame_count = 0
            processed_count = 0

            def flush_batch():
                """Run the pending batch through CLIP/BLIP-2 and collect the results."""
                nonlocal processed_count, current_batch, batch_positions
                if not current_batch:
                    return
                progress(processed_count / frames_to_process,
                         desc=f"Processing frames... {processed_count}/{frames_to_process}")
                features, captions = self.process_frame_batch(current_batch)
                if features is not None and captions is not None:
                    for i, (feat, cap_text) in enumerate(zip(features, captions)):
                        features_list.append(feat)
                        frame_data.append({
                            'frame_number': batch_positions[i],
                            'timestamp': batch_positions[i] / fps,
                            'caption': cap_text
                        })
                processed_count += len(current_batch)
                current_batch = []
                batch_positions = []

            while processed_count < frames_to_process:
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % self.frame_interval == 0:
                    current_batch.append(frame)
                    batch_positions.append(frame_count)

                if len(current_batch) == self.batch_size:
                    flush_batch()

                frame_count += 1

            # Flush any partial batch left over when the video or the frame budget ran out.
            flush_batch()

            if features_list:
                features_array = np.vstack(features_list).astype(np.float32)
                frame_index = faiss.IndexFlatL2(features_array.shape[1])
                frame_index.add(features_array)
                return frame_index, frame_data, "Video processed successfully!"
            else:
                return None, None, "No frames were processed successfully."

        except Exception as e:
            logging.error(f"Error processing video: {str(e)}")
            return None, None, f"Error processing video: {str(e)}"

        finally:
            if cap is not None:
                cap.release()
            gc.collect()
            if self.device.type == "cuda":
                torch.cuda.empty_cache()

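# Gradio front end: wires the upload/processing step and the question box to the
# VideoProcessor above and renders the retrieved frames plus their captions.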
class VideoQAInterface:
    def __init__(self):
        self.processor = VideoProcessor()
        self.frame_index = None
        self.frame_data = None
        self.processed = False
        self.current_video_path = None
        self.temp_dir = tempfile.mkdtemp()
        logging.info(f"Initialized temp directory: {self.temp_dir}")

    def __del__(self):
        """Cleanup temporary files"""
        try:
            if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
                shutil.rmtree(self.temp_dir)
                logging.info(f"Cleaned up temp directory: {self.temp_dir}")
        except Exception as e:
            logging.error(f"Error cleaning up temp directory: {str(e)}")

    def process_video(self, video_file, progress=gr.Progress()):
        """Process video with progress tracking"""
        if video_file is None:
            return "Please upload a video first."

        try:
            # Depending on the Gradio version, gr.File yields either a path string
            # or a temp-file object with a .name attribute; handle both.
            source_path = video_file if isinstance(video_file, str) else video_file.name
            temp_video_path = os.path.join(self.temp_dir, "input_video.mp4")
            shutil.copy2(source_path, temp_video_path)
            self.current_video_path = temp_video_path
            logging.info(f"Saved video to: {self.current_video_path}")

            progress(0, desc="Starting video processing...")
            self.frame_index, self.frame_data, message = self.processor.process_video(
                self.current_video_path, progress
            )

            if self.frame_index is not None:
                self.processed = True
                return "Video processed successfully! You can now ask questions."
            else:
                self.processed = False
                return message

        except Exception as e:
            self.processed = False
            logging.error(f"Error processing video: {str(e)}")
            return f"Error processing video: {str(e)}"

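    # Retrieval: the question is embedded with CLIP's text encoder, matched against
    # the FAISS index of frame embeddings, and the closest frames are re-read from
    # the saved video file for display alongside their BLIP-2 captions.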
    @torch.no_grad()
    def answer_question(self, query):
        """Answer questions about the video"""
        if not self.processed or self.current_video_path is None:
            return None, "Please process a video first."

        try:
            inputs = self.processor.clip_processor(text=[query], return_tensors="pt").to(self.processor.device)
            query_features = self.processor.clip_model.get_text_features(**inputs)

            # FAISS expects float32, and we cannot ask for more neighbors than the index holds.
            k = min(4, self.frame_index.ntotal)
            D, I = self.frame_index.search(query_features.float().cpu().numpy(), k)

            results = []
            for distance, idx in zip(D[0], I[0]):
                if idx < 0:
                    continue  # FAISS pads with -1 when fewer than k results exist
                frame_info = self.frame_data[idx].copy()
                frame_info['relevance'] = float(1 / (1 + distance))
                results.append(frame_info)

            descriptions = []
            frames = []

            cap = cv2.VideoCapture(self.current_video_path)
            try:
                for result in results:
                    frame_number = result['frame_number']
                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
                    ret, frame = cap.read()

                    if ret:
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frames.append(Image.fromarray(frame_rgb))

                        desc = f"Timestamp: {result['timestamp']:.2f}s\n"
                        desc += f"Scene Description: {result['caption']}\n"
                        desc += f"Relevance Score: {result['relevance']:.2f}"
                        descriptions.append(desc)
            finally:
                cap.release()

            if not frames:
                return None, "No relevant frames found."

            combined_desc = "\n\nFrame Analysis:\n\n"
            for i, desc in enumerate(descriptions, 1):
                combined_desc += f"Frame {i}:\n{desc}\n\n"

            return frames, combined_desc

        except Exception as e:
            logging.error(f"Error answering question: {str(e)}")
            return None, f"Error answering question: {str(e)}"

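    # UI layout: file upload plus a "Process Video" button, a query box with a
    # "Search" button, a gallery for retrieved frames, and a textbox for analysis.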
    def create_interface(self):
        """Create Gradio interface"""
        with gr.Blocks(title="Advanced Video Question Answering") as interface:
            gr.Markdown("# Advanced Video Question Answering")
            gr.Markdown("Upload a video and ask questions about any aspect of its content!")

            with gr.Row():
                with gr.Column():
                    video_input = gr.File(
                        label="Upload Video",
                        file_types=["video"]
                    )
                    status = gr.Textbox(label="Status", interactive=False)
                    process_btn = gr.Button("Process Video")

            with gr.Row():
                query_input = gr.Textbox(
                    label="Ask about the video",
                    placeholder="What's happening in the video?"
                )
                query_btn = gr.Button("Search")

            gallery = gr.Gallery(
                label="Retrieved Frames",
                show_label=True,
                columns=[2],
                rows=[2]
            )

            descriptions = gr.Textbox(
                label="Analysis",
                interactive=False,
                lines=10
            )

            process_btn.click(
                fn=self.process_video,
                inputs=[video_input],
                outputs=[status]
            )

            query_btn.click(
                fn=self.answer_question,
                inputs=[query_input],
                outputs=[gallery, descriptions]
            )

        return interface

app = VideoQAInterface()
interface = app.create_interface()

if __name__ == "__main__":
    interface.launch(
        server_name="0.0.0.0",
        share=False,
        show_error=True
    )
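# Assumed usage (the file name is hypothetical): run `python video_qa_app.py` and
# open the Gradio UI in a browser (Gradio serves on port 7860 by default).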