"""Video RAG tool: index video frames with CLIP embeddings and BLIP captions,
then answer natural-language queries about the video through a Gradio UI."""

import logging
import math
import os
import shutil
import tempfile
from typing import Dict, List, Tuple

import cv2
import faiss
import gradio as gr
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from transformers import BlipForConditionalGeneration, BlipProcessor, CLIPModel, CLIPProcessor
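
# Third-party dependencies (PyPI package names are an assumption inferred from
# the imports above): opencv-python (cv2), numpy, torch, transformers,
# pillow (PIL), faiss-cpu or faiss-gpu (faiss), gradio, and tqdm.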


class VideoRAGTool:
    def __init__(self, clip_model_name: str = "openai/clip-vit-base-patch32",
                 blip_model_name: str = "Salesforce/blip-image-captioning-base"):
        """Load CLIP and BLIP models and set up inference defaults."""
        self.logger = self.setup_logger()
        self.logger.info("Initializing VideoRAGTool...")

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.logger.info(f"Using device: {self.device}")

        # CLIP provides image/text embeddings for retrieval; BLIP generates captions.
        self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(self.device)
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
        self.blip_processor = BlipProcessor.from_pretrained(blip_model_name)
        self.blip_model = BlipForConditionalGeneration.from_pretrained(blip_model_name).to(self.device)

        # Inference only: switch both models to eval mode.
        self.clip_model.eval()
        self.blip_model.eval()

        # Number of sampled frames embedded/captioned per batch (kept small to limit memory use).
        self.batch_size = 4

        # FAISS index over frame embeddings and per-frame metadata (populated by process_video).
        self.frame_index = None
        self.frame_data = []

    def setup_logger(self) -> logging.Logger:
        """Set up logging configuration."""
        logger = logging.getLogger('VideoRAGTool')

        # Clear handlers left over from a previous instance to avoid duplicate log lines.
        if logger.handlers:
            logger.handlers.clear()

        logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        return logger
    @torch.no_grad()
    def generate_caption(self, image: Image.Image) -> str:
        """Generate a short caption for a single frame with BLIP."""
        try:
            inputs = self.blip_processor(image, return_tensors="pt").to(self.device)
            out = self.blip_model.generate(**inputs, max_length=30, num_beams=2)
            caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
            return caption
        except Exception as e:
            self.logger.error(f"Error generating caption: {str(e)}")
            return "Caption generation failed"

    def get_video_info(self, video_path: str) -> Tuple[int, float]:
        """Return total frame count and FPS for a video file."""
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {video_path}")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()
        if fps <= 0:
            # Some containers report no FPS; fall back so timestamps stay finite.
            fps = 30.0
        return total_frames, fps

    def preprocess_frame(self, frame: np.ndarray, target_size: Tuple[int, int] = (224, 224)) -> Image.Image:
        """Preprocess frame with resizing for efficiency."""
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(frame_rgb)
        return image.resize(target_size, Image.LANCZOS)
    @torch.no_grad()
    def process_batch(self, frames: List[Image.Image]) -> Tuple[np.ndarray, List[str]]:
        """Embed and caption a batch of frames."""
        try:
            # CLIP image embeddings for the whole batch at once.
            clip_inputs = self.clip_processor(images=frames, return_tensors="pt", padding=True).to(self.device)
            image_features = self.clip_model.get_image_features(**clip_inputs)

            # BLIP captions are generated one frame at a time.
            captions = []
            for frame in frames:
                captions.append(self.generate_caption(frame))

            return image_features.cpu().numpy(), captions
        except Exception as e:
            self.logger.error(f"Error processing batch: {str(e)}")
            raise

    def process_video(self, video_path: str, frame_interval: int = 30) -> None:
        """Sample every `frame_interval`-th frame, embed and caption it, and build the FAISS index."""
        self.logger.info(f"Processing video: {video_path}")

        # Reset state from any previously processed video so FAISS indices
        # keep lining up with frame_data entries.
        self.frame_index = None
        self.frame_data = []

        try:
            total_frames, fps = self.get_video_info(video_path)
            cap = cv2.VideoCapture(video_path)

            frames_to_process = math.ceil(total_frames / frame_interval)

            current_batch = []
            current_batch_frames = []  # original frame indices for the frames in current_batch
            features_list = []
            frame_count = 0

            with tqdm(total=frames_to_process, desc="Processing frames") as pbar:
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break

                    if frame_count % frame_interval == 0:
                        current_batch.append(self.preprocess_frame(frame))
                        current_batch_frames.append(frame_count)

                        if len(current_batch) == self.batch_size:
                            batch_features, batch_captions = self.process_batch(current_batch)

                            for frame_number, features, caption in zip(current_batch_frames, batch_features, batch_captions):
                                self.frame_data.append({
                                    'frame_number': frame_number,
                                    'timestamp': frame_number / fps,
                                    'caption': caption
                                })
                                features_list.append(features)

                            current_batch = []
                            current_batch_frames = []
                            pbar.update(self.batch_size)

                    frame_count += 1

                # Flush the final, possibly partial, batch.
                if current_batch:
                    batch_features, batch_captions = self.process_batch(current_batch)
                    for frame_number, features, caption in zip(current_batch_frames, batch_features, batch_captions):
                        self.frame_data.append({
                            'frame_number': frame_number,
                            'timestamp': frame_number / fps,
                            'caption': caption
                        })
                        features_list.append(features)
                    pbar.update(len(current_batch))

            cap.release()

            if not features_list:
                raise ValueError("No frames were processed from the video")

            # Build an exact L2 index over the frame embeddings.
            features_array = np.vstack(features_list)
            self.frame_index = faiss.IndexFlatL2(features_array.shape[1])
            self.frame_index.add(features_array)

            self.logger.info(f"Processed {len(self.frame_data)} frames from video")

        except Exception as e:
            self.logger.error(f"Error processing video: {str(e)}")
            raise

    @torch.no_grad()
    def query_video(self, query_text: str, k: int = 5) -> List[Dict]:
        """Query the video using natural language and return relevant frames."""
        self.logger.info(f"Processing query: {query_text}")

        if self.frame_index is None or not self.frame_data:
            raise ValueError("No video has been processed yet")

        try:
            inputs = self.clip_processor(text=[query_text], return_tensors="pt",
                                         padding=True, truncation=True).to(self.device)
            text_features = self.clip_model.get_text_features(**inputs)

            # Never ask FAISS for more neighbours than there are indexed frames.
            k = min(k, len(self.frame_data))
            distances, indices = self.frame_index.search(text_features.cpu().numpy(), k)

            results = []
            for distance, idx in zip(distances[0], indices[0]):
                frame_info = self.frame_data[idx].copy()
                # Map L2 distance to a score in (0, 1]; smaller distance -> higher score.
                frame_info['relevance_score'] = float(1 / (1 + distance))
                results.append(frame_info)

            return results
        except Exception as e:
            self.logger.error(f"Error querying video: {str(e)}")
            raise


class VideoRAGApp:
    def __init__(self):
        self.rag_tool = VideoRAGTool()
        self.current_video_path = None
        self.processed = False
        self.temp_dir = tempfile.mkdtemp()

    def __del__(self):
        """Clean up temporary files on deletion."""
        if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
            shutil.rmtree(self.temp_dir, ignore_errors=True)

    def process_video(self, video_file):
        """Process an uploaded video and return a status message."""
        try:
            if video_file is None:
                return "Please upload a video first."

            # gr.File may return a plain file path (newer Gradio) or a tempfile-like
            # object with a .name attribute (older Gradio); handle both.
            video_path = video_file if isinstance(video_file, str) else video_file.name
            temp_video_path = os.path.join(self.temp_dir, "current_video.mp4")
            shutil.copy2(video_path, temp_video_path)

            self.current_video_path = temp_video_path

            self.rag_tool.process_video(self.current_video_path)
            self.processed = True
            return "Video processed successfully! You can now ask questions about the video."

        except Exception as e:
            self.processed = False
            return f"Error processing video: {str(e)}"

    def query_video(self, query_text):
        """Query the video and return relevant frames with descriptions."""
        if not self.processed:
            return None, "Please process a video first."
        if not query_text or not query_text.strip():
            return None, "Please enter a question about the video."

        try:
            results = self.rag_tool.query_video(query_text, k=4)

            frames = []
            descriptions = []

            cap = cv2.VideoCapture(self.current_video_path)

            for result in results:
                frame_number = result['frame_number']
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
                ret, frame = cap.read()

                if ret:
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame_rgb))

                    description = f"Timestamp: {result['timestamp']:.2f}s\n"
                    description += f"Scene Description: {result['caption']}\n"
                    description += f"Relevance Score: {result['relevance_score']:.2f}"
                    descriptions.append(description)

            cap.release()

            combined_description = "\n\nFrame Analysis:\n\n"
            for i, desc in enumerate(descriptions, 1):
                combined_description += f"Frame {i}:\n{desc}\n\n"

            return frames, combined_description

        except Exception as e:
            return None, f"Error querying video: {str(e)}"

    def create_interface(self):
        """Create and return the Gradio interface."""
        with gr.Blocks(title="Video Chat RAG") as interface:
            gr.Markdown("# Video Chat RAG")
            gr.Markdown("Upload a video and ask questions about its content!")

            with gr.Row():
                video_input = gr.File(
                    label="Upload Video",
                    file_types=["video"],
                )
                process_button = gr.Button("Process Video")

            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )

            with gr.Row():
                query_input = gr.Textbox(
                    label="Ask about the video",
                    placeholder="What's happening in the video?"
                )
                query_button = gr.Button("Search")

            with gr.Row():
                gallery = gr.Gallery(
                    label="Retrieved Frames",
                    show_label=True,
                    elem_id="gallery",
                    columns=[2],
                    rows=[2],
                    height="auto"
                )

            descriptions = gr.Textbox(
                label="Scene Descriptions",
                interactive=False,
                lines=10
            )

            process_button.click(
                fn=self.process_video,
                inputs=[video_input],
                outputs=[status_output]
            )

            query_button.click(
                fn=self.query_video,
                inputs=[query_input],
                outputs=[gallery, descriptions]
            )

        return interface
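

# Example of driving VideoRAGTool directly, without the Gradio UI. A minimal
# sketch: "sample.mp4" and the query string are placeholders (assumptions), and
# the function is defined for illustration only; nothing in this script calls it.
def _example_programmatic_query(video_path: str = "sample.mp4") -> None:
    tool = VideoRAGTool()
    tool.process_video(video_path, frame_interval=30)
    for hit in tool.query_video("a person talking to the camera", k=3):
        print(f"{hit['timestamp']:.1f}s | {hit['caption']} (score={hit['relevance_score']:.2f})")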


# The app and interface are built at import time; the UI is launched only when
# the script is run directly.
app = VideoRAGApp()
interface = app.create_interface()


if __name__ == "__main__":
    interface.launch()