capradeepgujaran committed on
Commit
5f52218
1 Parent(s): 8ad7e0c

Update app.py

Files changed (1):
  1. app.py (+128, -99)
app.py CHANGED
@@ -4,7 +4,6 @@ from transformers import CLIPProcessor, CLIPModel, BlipProcessor, BlipForConditi
 import torch
 from PIL import Image
 import faiss
-import pickle
 from typing import List, Dict, Tuple
 import logging
 import gradio as gr
@@ -12,16 +11,18 @@ import tempfile
 import os
 import shutil
 from tqdm import tqdm
-import torch.nn as nn
 import math
-from concurrent.futures import ThreadPoolExecutor
-import numpy as np

 class VideoRAGTool:
     def __init__(self, clip_model_name: str = "openai/clip-vit-base-patch32",
                  blip_model_name: str = "Salesforce/blip-image-captioning-base"):
         """Initialize with performance optimizations."""
+        # Setup logger first to avoid the attribute error
+        self.logger = self.setup_logger()
+
+        self.logger.info("Initializing VideoRAGTool...")
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.logger.info(f"Using device: {self.device}")

         # Initialize models with optimization flags
         self.clip_model = CLIPModel.from_pretrained(clip_model_name).to(self.device)
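
Two notes on this hunk. First, it deletes `import numpy as np`, yet the new code below still uses `np.ndarray` in the `process_batch` signature and `np.vstack` when building the FAISS index; unless numpy is also imported somewhere in the file's unchanged first three lines, that deletion would reintroduce a `NameError`. Second, the substantive fix is ordering: the logger is now created before any other initialization step, so code that runs (or fails) during setup can safely call `self.logger`. A minimal sketch of the failure mode being fixed, with hypothetical class names:

import logging

class LateLogger:
    """Reproduces the bug: the logger is created at the end of __init__."""
    def __init__(self):
        self.do_setup()                          # wants to log, but...
        self.logger = logging.getLogger("demo")  # ...is created too late

    def do_setup(self):
        self.logger.info("setting up")           # AttributeError here

class EarlyLogger:
    """The fix applied above: create the logger before anything else."""
    def __init__(self):
        self.logger = logging.getLogger("demo")
        self.do_setup()                          # can now log safely

    def do_setup(self):
        self.logger.info("setting up")

try:
    LateLogger()
except AttributeError as e:
    print(e)       # 'LateLogger' object has no attribute 'logger'

EarlyLogger()      # initializes cleanly
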
@@ -34,18 +35,37 @@ class VideoRAGTool:
         self.blip_model.eval()

         # Batch processing settings
-        self.batch_size = 8  # Adjust based on your GPU memory
+        self.batch_size = 4  # Reduced batch size for better memory management

         self.frame_index = None
         self.frame_data = []
-        self.logger = self._setup_logger()

-    @torch.no_grad()  # Disable gradient computation for inference
+    def setup_logger(self) -> logging.Logger:
+        """Set up logging configuration."""
+        logger = logging.getLogger('VideoRAGTool')
+
+        # Clear any existing handlers
+        if logger.handlers:
+            logger.handlers.clear()
+
+        logger.setLevel(logging.INFO)
+        handler = logging.StreamHandler()
+        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    @torch.no_grad()
     def generate_caption(self, image: Image.Image) -> str:
         """Optimized caption generation."""
-        inputs = self.blip_processor(image, return_tensors="pt").to(self.device)
-        out = self.blip_model.generate(**inputs, max_length=30, num_beams=2)
-        return self.blip_processor.decode(out[0], skip_special_tokens=True)
+        try:
+            inputs = self.blip_processor(image, return_tensors="pt").to(self.device)
+            out = self.blip_model.generate(**inputs, max_length=30, num_beams=2)
+            caption = self.blip_processor.decode(out[0], skip_special_tokens=True)
+            return caption
+        except Exception as e:
+            self.logger.error(f"Error generating caption: {str(e)}")
+            return "Caption generation failed"

     def get_video_info(self, video_path: str) -> Tuple[int, float]:
         """Get video frame count and FPS."""
@@ -64,110 +84,119 @@ class VideoRAGTool:
     @torch.no_grad()
     def process_batch(self, frames: List[Image.Image]) -> Tuple[np.ndarray, List[str]]:
         """Process a batch of frames efficiently."""
-        # CLIP processing
-        clip_inputs = self.clip_processor(images=frames, return_tensors="pt", padding=True).to(self.device)
-        image_features = self.clip_model.get_image_features(**clip_inputs)
-
-        # BLIP processing
-        captions = []
-        blip_inputs = self.blip_processor(images=frames, return_tensors="pt", padding=True).to(self.device)
-        out = self.blip_model.generate(**blip_inputs, max_length=30, num_beams=2)
-
-        for o in out:
-            caption = self.blip_processor.decode(o, skip_special_tokens=True)
-            captions.append(caption)
-
-        return image_features.cpu().numpy(), captions
+        try:
+            # CLIP processing
+            clip_inputs = self.clip_processor(images=frames, return_tensors="pt", padding=True).to(self.device)
+            image_features = self.clip_model.get_image_features(**clip_inputs)
+
+            # BLIP processing
+            captions = []
+            for frame in frames:
+                caption = self.generate_caption(frame)
+                captions.append(caption)
+
+            return image_features.cpu().numpy(), captions
+        except Exception as e:
+            self.logger.error(f"Error processing batch: {str(e)}")
+            raise

     def process_video(self, video_path: str, frame_interval: int = 30) -> None:
         """Optimized video processing with batching and progress tracking."""
         self.logger.info(f"Processing video: {video_path}")

-        total_frames, fps = self.get_video_info(video_path)
-        cap = cv2.VideoCapture(video_path)
-
-        # Calculate total batches for progress bar
-        frames_to_process = total_frames // frame_interval
-        total_batches = math.ceil(frames_to_process / self.batch_size)
-
-        current_batch = []
-        features_list = []
-        frame_count = 0
-
-        with tqdm(total=frames_to_process, desc="Processing frames") as pbar:
-            while cap.isOpened():
-                ret, frame = cap.read()
-                if not ret:
-                    break
-
-                if frame_count % frame_interval == 0:
-                    # Preprocess frame
-                    processed_frame = self.preprocess_frame(frame)
-                    current_batch.append(processed_frame)
-
-                    # Process batch when it reaches batch_size
-                    if len(current_batch) == self.batch_size:
-                        batch_features, batch_captions = self.process_batch(current_batch)
-
-                        # Store results
-                        for i, (features, caption) in enumerate(zip(batch_features, batch_captions)):
-                            batch_frame_number = frame_count - (self.batch_size - i - 1) * frame_interval
-                            self.frame_data.append({
-                                'frame_number': batch_frame_number,
-                                'timestamp': batch_frame_number / fps,
-                                'caption': caption
-                            })
-                            features_list.append(features)
-
-                        current_batch = []
-                        pbar.update(self.batch_size)
-
-                frame_count += 1
-
-        # Process remaining frames
-        if current_batch:
-            batch_features, batch_captions = self.process_batch(current_batch)
-            for i, (features, caption) in enumerate(zip(batch_features, batch_captions)):
-                batch_frame_number = frame_count - (len(current_batch) - i - 1) * frame_interval
-                self.frame_data.append({
-                    'frame_number': batch_frame_number,
-                    'timestamp': batch_frame_number / fps,
-                    'caption': caption
-                })
-                features_list.append(features)
-
-        cap.release()
-
-        if not features_list:
-            raise ValueError("No frames were processed from the video")
-
-        # Create FAISS index
-        features_array = np.vstack(features_list)
-        self.frame_index = faiss.IndexFlatL2(features_array.shape[1])
-        self.frame_index.add(features_array)
-
-        self.logger.info(f"Processed {len(self.frame_data)} frames from video")
-
+        try:
+            total_frames, fps = self.get_video_info(video_path)
+            cap = cv2.VideoCapture(video_path)
+
+            # Calculate total batches for progress bar
+            frames_to_process = total_frames // frame_interval
+            total_batches = math.ceil(frames_to_process / self.batch_size)
+
+            current_batch = []
+            features_list = []
+            frame_count = 0
+
+            with tqdm(total=frames_to_process, desc="Processing frames") as pbar:
+                while cap.isOpened():
+                    ret, frame = cap.read()
+                    if not ret:
+                        break
+
+                    if frame_count % frame_interval == 0:
+                        # Preprocess frame
+                        processed_frame = self.preprocess_frame(frame)
+                        current_batch.append(processed_frame)
+
+                        # Process batch when it reaches batch_size
+                        if len(current_batch) == self.batch_size:
+                            batch_features, batch_captions = self.process_batch(current_batch)
+
+                            # Store results
+                            for i, (features, caption) in enumerate(zip(batch_features, batch_captions)):
+                                batch_frame_number = frame_count - (self.batch_size - i - 1) * frame_interval
+                                self.frame_data.append({
+                                    'frame_number': batch_frame_number,
+                                    'timestamp': batch_frame_number / fps,
+                                    'caption': caption
+                                })
+                                features_list.append(features)
+
+                            current_batch = []
+                            pbar.update(self.batch_size)
+
+                    frame_count += 1
+
+            # Process remaining frames
+            if current_batch:
+                batch_features, batch_captions = self.process_batch(current_batch)
+                for i, (features, caption) in enumerate(zip(batch_features, batch_captions)):
+                    batch_frame_number = frame_count - (len(current_batch) - i - 1) * frame_interval
+                    self.frame_data.append({
+                        'frame_number': batch_frame_number,
+                        'timestamp': batch_frame_number / fps,
+                        'caption': caption
+                    })
+                    features_list.append(features)
+
+            cap.release()
+
+            if not features_list:
+                raise ValueError("No frames were processed from the video")
+
+            # Create FAISS index
+            features_array = np.vstack(features_list)
+            self.frame_index = faiss.IndexFlatL2(features_array.shape[1])
+            self.frame_index.add(features_array)
+
+            self.logger.info(f"Processed {len(self.frame_data)} frames from video")
+
+        except Exception as e:
+            self.logger.error(f"Error processing video: {str(e)}")
+            raise

     def query_video(self, query_text: str, k: int = 5) -> List[Dict]:
         """Query the video using natural language and return relevant frames."""
         self.logger.info(f"Processing query: {query_text}")

-        inputs = self.clip_processor(text=[query_text], return_tensors="pt").to(self.device)
-        text_features = self.clip_model.get_text_features(**inputs)
-
-        distances, indices = self.frame_index.search(
-            text_features.cpu().detach().numpy(),
-            k
-        )
-
-        results = []
-        for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
-            frame_info = self.frame_data[idx].copy()
-            frame_info['relevance_score'] = float(1 / (1 + distance))
-            results.append(frame_info)
-
-        return results
+        try:
+            inputs = self.clip_processor(text=[query_text], return_tensors="pt").to(self.device)
+            text_features = self.clip_model.get_text_features(**inputs)
+
+            distances, indices = self.frame_index.search(
+                text_features.cpu().detach().numpy(),
+                k
+            )
+
+            results = []
+            for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
+                frame_info = self.frame_data[idx].copy()
+                frame_info['relevance_score'] = float(1 / (1 + distance))
+                results.append(frame_info)
+
+            return results
+        except Exception as e:
+            self.logger.error(f"Error querying video: {str(e)}")
+            raise

 class VideoRAGApp:
     def __init__(self):
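
The subtlest arithmetic in the rewritten `process_video` is recovering each frame's original index after batching: when a full batch is flushed, `frame_count` still points at the last sampled frame, and the i-th batch entry was sampled `(batch_size - i - 1)` intervals earlier. A quick standalone check of that bookkeeping, with values mirroring the defaults:

# Verify the batch_frame_number formula used in process_video:
# frames 0, 30, 60, 90 fill a batch of 4 (frame_interval=30), and the
# flush happens while frame_count == 90, before the += 1 increment.
frame_interval = 30
batch_size = 4

sampled = []  # ground truth: the frames actually appended to the batch
for frame_count in range(120):
    if frame_count % frame_interval == 0:
        sampled.append(frame_count)
        if len(sampled) % batch_size == 0:
            recovered = [
                frame_count - (batch_size - i - 1) * frame_interval
                for i in range(batch_size)
            ]
            assert recovered == sampled[-batch_size:]
            print(recovered)  # [0, 30, 60, 90]

Note that the leftover-batch branch reuses the same formula with `len(current_batch)`, but by that point the read loop has already incremented `frame_count` past the last sampled frame, so the recovered indices and timestamps for a partial final batch appear to be shifted forward by up to one `frame_interval`; that looks like a residual quirk of this commit rather than intended behavior.
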
 
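For orientation, here is a hypothetical driver for the class as it stands after this commit. The video path and query are illustrative only, and `preprocess_frame` plus the `VideoRAGApp` UI wiring live elsewhere in app.py, outside this diff:

# Hypothetical usage of VideoRAGTool as changed in this commit.
tool = VideoRAGTool()
tool.process_video("sample.mp4", frame_interval=30)  # builds the FAISS index

for hit in tool.query_video("a person riding a bicycle", k=3):
    # relevance_score maps L2 distance d to 1 / (1 + d): smaller distances
    # give scores closer to 1 (the scores are relative, not calibrated).
    print(f"{hit['timestamp']:.1f}s  {hit['relevance_score']:.3f}  {hit['caption']}")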