capradeepgujaran committed on
Commit 007d795
1 Parent(s): e25cab4

Update app.py

Files changed (1)
  1. app.py +159 -209
app.py CHANGED
@@ -1,11 +1,6 @@
 import cv2
 import numpy as np
-from transformers import (
-    CLIPProcessor, CLIPModel,
-    BlipProcessor, BlipForConditionalGeneration,
-    Blip2Processor, Blip2ForConditionalGeneration,
-    AutoProcessor, AutoModelForObjectDetection
-)
 import torch
 from PIL import Image
 import faiss
@@ -16,43 +11,37 @@ import tempfile
 import os
 import shutil
 from tqdm import tqdm

-class EnhancedVideoAnalyzer:
     def __init__(self):
         self.logger = self.setup_logger()
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.logger.info(f"Using device: {self.device}")

-        # Initialize CLIP for general scene understanding
-        self.logger.info("Loading CLIP model...")
-        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
-        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
-
-        # Initialize BLIP-2 for detailed scene description
-        self.logger.info("Loading BLIP-2 model...")
-        self.blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
-        self.blip2_model = Blip2ForConditionalGeneration.from_pretrained(
-            "Salesforce/blip2-opt-2.7b",
             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
         ).to(self.device)

-        # Initialize Object Detection model
-        self.logger.info("Loading object detection model...")
-        self.obj_processor = AutoProcessor.from_pretrained("microsoft/table-transformer-detection")
-        self.obj_model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection").to(self.device)
-
         self.frame_index = None
         self.frame_data = []
-        self.target_size = (384, 384)  # Increased size for better detail recognition
-        self.batch_size = 4

-        # Set all models to evaluation mode
-        self.clip_model.eval()
-        self.blip2_model.eval()
-        self.obj_model.eval()

     def setup_logger(self) -> logging.Logger:
-        logger = logging.getLogger('EnhancedVideoAnalyzer')
         if logger.handlers:
             logger.handlers.clear()
         logger.setLevel(logging.INFO)
@@ -62,256 +51,218 @@ class EnhancedVideoAnalyzer:
         logger.addHandler(handler)
         return logger

     @torch.no_grad()
     def analyze_frame(self, image: Image.Image) -> Dict:
         """Comprehensive frame analysis"""
         try:
-            # 1. Generate detailed caption using BLIP-2
-            inputs = self.blip2_processor(image, return_tensors="pt").to(self.device, torch.float16)
-            caption = self.blip2_model.generate(**inputs, max_new_tokens=50)
-            caption_text = self.blip2_processor.decode(caption[0], skip_special_tokens=True)
-
-            # 2. Detect objects
-            obj_inputs = self.obj_processor(images=image, return_tensors="pt").to(self.device)
-            obj_outputs = self.obj_model(**obj_inputs)
-
-            # Process object detection results
-            target_sizes = torch.tensor([image.size[::-1]])
-            results = self.obj_processor.post_process_object_detection(
-                obj_outputs, threshold=0.5, target_sizes=target_sizes
-            )[0]
-
-            detected_objects = []
-            for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
-                detected_objects.append({
-                    "label": self.obj_processor.model.config.id2label[label.item()],
-                    "confidence": score.item()
-                })

             return {
                 "caption": caption_text,
-                "objects": detected_objects
             }
-
         except Exception as e:
-            self.logger.error(f"Error in frame analysis: {str(e)}")
-            return {"caption": "Error analyzing frame", "objects": []}

-    def extract_keyframes(self, video_path: str, max_frames: int = 15) -> List[Tuple[int, np.ndarray]]:
-        """Extract key frames using scene detection"""
         cap = cv2.VideoCapture(video_path)
         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        fps = cap.get(cv2.CAP_PROP_FPS)
-
-        # Calculate frame interval to get approximately max_frames
-        frame_interval = max(1, total_frames // max_frames)
-
-        frames = []
-        frame_positions = []
-        prev_gray = None
-
-        with tqdm(total=total_frames, desc="Extracting frames") as pbar:
-            while cap.isOpened() and len(frames) < max_frames:
                 ret, frame = cap.read()
                 if not ret:
                     break

-                # Convert to grayscale for scene detection
-                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
-
-                if prev_gray is not None:
-                    # Calculate frame difference
-                    diff = cv2.absdiff(gray, prev_gray)
-                    mean_diff = np.mean(diff)

-                    # If significant change or first/last frame
-                    if mean_diff > 30 or len(frames) == 0:
-                        frames.append(frame)
-                        frame_positions.append(cap.get(cv2.CAP_PROP_POS_FRAMES))
-
-                prev_gray = gray
-                pbar.update(1)
-
         cap.release()
-        return list(zip(frame_positions, frames))

-    @torch.no_grad()
-    def process_video(self, video_path: str) -> None:
-        """Process video with comprehensive analysis"""
         self.logger.info(f"Processing video: {video_path}")
-        self.frame_data = []
-        features_list = []

         try:
-            # Extract key frames
-            keyframes = self.extract_keyframes(video_path)
-            self.logger.info(f"Extracted {len(keyframes)} key frames")
-
-            # Process frames with progress bar
-            with tqdm(total=len(keyframes), desc="Analyzing frames") as pbar:
-                for frame_pos, frame in keyframes:
-                    # Convert frame to PIL Image
-                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                    image = Image.fromarray(frame_rgb).resize(self.target_size, Image.LANCZOS)
-
-                    # Analyze frame
-                    analysis = self.analyze_frame(image)
-
-                    # Get CLIP features
-                    clip_inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device)
-                    image_features = self.clip_model.get_image_features(**clip_inputs)
-
-                    # Store results
-                    self.frame_data.append({
-                        'frame_number': int(frame_pos),
-                        'timestamp': frame_pos / 30.0,  # Approximate timestamp
-                        'caption': analysis['caption'],
-                        'objects': analysis['objects']
-                    })
-
-                    features_list.append(image_features.cpu().numpy())
-                    pbar.update(1)
-
-            # Create FAISS index
-            if features_list:
-                features_array = np.vstack(features_list)
-                self.frame_index = faiss.IndexFlatL2(features_array.shape[1])
-                self.frame_index.add(features_array)

-            self.logger.info("Video processing completed successfully")

         except Exception as e:
-            self.logger.error(f"Error processing video: {str(e)}")
-            raise

     @torch.no_grad()
-    def query_video(self, query_text: str, k: int = 4) -> List[Dict]:
-        """Enhanced query processing"""
         try:
-            # Process query with CLIP
-            text_inputs = self.clip_processor(text=[query_text], return_tensors="pt").to(self.device)
-            text_features = self.clip_model.get_text_features(**text_inputs)
-
-            # Search for relevant frames
             distances, indices = self.frame_index.search(
-                text_features.cpu().numpy(),
                 k
             )
-
-            # Prepare results with enhanced information
             results = []
             for distance, idx in zip(distances[0], indices[0]):
                 frame_info = self.frame_data[idx].copy()
-
-                # Add relevance score
-                frame_info['relevance_score'] = float(1 / (1 + distance))
-
-                # Add object summary
-                obj_summary = ", ".join(obj["label"] for obj in frame_info['objects'][:3])
-                if obj_summary:
-                    frame_info['object_summary'] = f"Objects detected: {obj_summary}"
-
                 results.append(frame_info)
-
             return results
-
         except Exception as e:
-            self.logger.error(f"Error querying video: {str(e)}")
-            raise

-class VideoQAApp:
     def __init__(self):
-        self.analyzer = EnhancedVideoAnalyzer()
-        self.current_video_path = None
         self.processed = False
-        self.temp_dir = tempfile.mkdtemp()
-
-    def __del__(self):
-        if hasattr(self, 'temp_dir') and os.path.exists(self.temp_dir):
-            shutil.rmtree(self.temp_dir, ignore_errors=True)

     def process_video(self, video_file):
-        """Process video with progress updates"""
         try:
             if video_file is None:
                 return "Please upload a video first.", gr.Progress(0)

-            video_path = video_file.name
-            temp_video_path = os.path.join(self.temp_dir, "current_video.mp4")
-            shutil.copy2(video_path, temp_video_path)
-
-            self.current_video_path = temp_video_path
-            self.analyzer.process_video(self.current_video_path)
-            self.processed = True
-
-            return "Video processed successfully! You can now ask questions about the video.", gr.Progress(100)

         except Exception as e:
             self.processed = False
-            return f"Error processing video: {str(e)}", gr.Progress(0)

-    def query_video(self, query_text):
-        """Query video with comprehensive results"""
         if not self.processed:
             return None, "Please process a video first."
-
         try:
-            results = self.analyzer.query_video(query_text)
             frames = []
             descriptions = []
-
-            cap = cv2.VideoCapture(self.current_video_path)
-
             for result in results:
-                frame_number = result['frame_number']
-                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
-                ret, frame = cap.read()

-                if ret:
-                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                    frames.append(Image.fromarray(frame_rgb))
-
-                    description = f"Timestamp: {result['timestamp']:.2f}s\n"
-                    description += f"Scene Description: {result['caption']}\n"
-                    if 'object_summary' in result:
-                        description += f"{result['object_summary']}\n"
-                    description += f"Relevance Score: {result['relevance_score']:.2f}"
-                    descriptions.append(description)
-
-            cap.release()
-
-            combined_description = "\n\nScene Analysis:\n\n"
             for i, desc in enumerate(descriptions, 1):
-                combined_description += f"Frame {i}:\n{desc}\n\n"
-
-            return frames, combined_description
-
         except Exception as e:
-            return None, f"Error querying video: {str(e)}"

     def create_interface(self):
         """Create Gradio interface"""
-        with gr.Blocks(title="Video Question Answering") as interface:
             gr.Markdown("# Advanced Video Question Answering")
             gr.Markdown("Upload a video and ask questions about any aspect of its content!")

             with gr.Row():
                 video_input = gr.File(
-                    label="Upload Video (Recommended: 30 seconds to 5 minutes)",
                     file_types=["video"],
                 )
                 process_button = gr.Button("Process Video")

-            with gr.Row():
-                status_output = gr.Textbox(
-                    label="Status",
-                    interactive=False
-                )
-                progress = gr.Progress()

             with gr.Row():
                 query_input = gr.Textbox(
-                    label="Ask anything about the video",
                     placeholder="What's happening in the video?"
                 )
                 query_button = gr.Button("Search")
@@ -319,7 +270,6 @@ class VideoQAApp:
                 gallery = gr.Gallery(
                     label="Retrieved Frames",
                     show_label=True,
-                    elem_id="gallery",
                     columns=[2],
                     rows=[2],
                     height="auto"
@@ -334,11 +284,11 @@ class VideoQAApp:
             process_button.click(
                 fn=self.process_video,
                 inputs=[video_input],
-                outputs=[status_output, progress]
             )

             query_button.click(
-                fn=self.query_video,
                 inputs=[query_input],
                 outputs=[gallery, descriptions]
             )
@@ -346,7 +296,7 @@ class VideoQAApp:
         return interface

 # Initialize and create the interface
-app = VideoQAApp()
 interface = app.create_interface()

 # Launch the app

 import cv2
 import numpy as np
+from transformers import CLIPProcessor, CLIPModel, Blip2Processor, Blip2ForConditionalGeneration
 import torch
 from PIL import Image
 import faiss

 import os
 import shutil
 from tqdm import tqdm
+from pathlib import Path
+from moviepy.video.io.VideoFileClip import VideoFileClip

+class VideoRAGSystem:
     def __init__(self):
         self.logger = self.setup_logger()
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.logger.info(f"Using device: {self.device}")

+        # Initialize models
+        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
+        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+        self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+        self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
+            "Salesforce/blip2-opt-2.7b",
             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
         ).to(self.device)

+        # Vector store setup
         self.frame_index = None
         self.frame_data = []
+        self.target_size = (224, 224)

+        # Create directories for storing processed data
+        self.temp_dir = tempfile.mkdtemp()
+        self.frames_dir = os.path.join(self.temp_dir, "frames")
+        os.makedirs(self.frames_dir, exist_ok=True)

     def setup_logger(self) -> logging.Logger:
+        logger = logging.getLogger('VideoRAGSystem')
         if logger.handlers:
             logger.handlers.clear()
         logger.setLevel(logging.INFO)

         logger.addHandler(handler)
         return logger

+    def split_video(self, video_path: str, timestamp_ms: int, context_seconds: int = 3) -> str:
+        """Extract a clip around the specified timestamp"""
+        timestamp_sec = timestamp_ms / 1000
+        output_path = os.path.join(self.temp_dir, "clip.mp4")
+
+        with VideoFileClip(video_path) as video:
+            duration = video.duration
+            start_time = max(timestamp_sec - context_seconds, 0)
+            end_time = min(timestamp_sec + context_seconds, duration)
+            clip = video.subclip(start_time, end_time)
+            clip.write_videofile(output_path, audio_codec='aac')
+
+        return output_path
+
     @torch.no_grad()
     def analyze_frame(self, image: Image.Image) -> Dict:
         """Comprehensive frame analysis"""
         try:
+            # Generate caption
+            inputs = self.blip_processor(image, return_tensors="pt").to(self.device)
+            if self.device.type == "cuda":
+                inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
+            caption = self.blip_model.generate(**inputs, max_length=50)
+            caption_text = self.blip_processor.decode(caption[0], skip_special_tokens=True)
+
+            # Get visual features
+            clip_inputs = self.clip_processor(images=image, return_tensors="pt").to(self.device)
+            if self.device.type == "cuda":
+                clip_inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in clip_inputs.items()}
+            features = self.clip_model.get_image_features(**clip_inputs)

             return {
                 "caption": caption_text,
+                "features": features.cpu().numpy()
             }
         except Exception as e:
+            self.logger.error(f"Frame analysis error: {str(e)}")
+            return None

+    def extract_keyframes(self, video_path: str, max_frames: int = 15) -> List[Dict]:
+        """Extract and analyze key frames"""
         cap = cv2.VideoCapture(video_path)
+        frames_info = []
+        frame_count = 0
         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        interval = max(1, total_frames // max_frames)
+
+        with tqdm(total=max_frames, desc="Analyzing frames") as pbar:
+            while len(frames_info) < max_frames and cap.isOpened():
                 ret, frame = cap.read()
                 if not ret:
                     break
+
+                if frame_count % interval == 0:
+                    # Save frame
+                    frame_path = os.path.join(self.frames_dir, f"frame_{frame_count}.jpg")
+                    cv2.imwrite(frame_path, frame)

+                    # Analyze frame
+                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                    image = Image.fromarray(frame_rgb).resize(self.target_size, Image.LANCZOS)
+                    analysis = self.analyze_frame(image)

+                    if analysis is not None:
+                        frames_info.append({
+                            "frame_number": frame_count,
+                            "timestamp": frame_count / cap.get(cv2.CAP_PROP_FPS),
+                            "path": frame_path,
+                            "caption": analysis["caption"],
+                            "features": analysis["features"]
+                        })
+                        pbar.update(1)
+
+                frame_count += 1
+
         cap.release()
+        return frames_info

+    def process_video(self, video_path: str):
+        """Process video and build search index"""
         self.logger.info(f"Processing video: {video_path}")

         try:
+            # Extract and analyze frames
+            frames_info = self.extract_keyframes(video_path)
+            self.frame_data = frames_info
+
+            # Build FAISS index
+            if frames_info:
+                features = np.vstack([frame["features"] for frame in frames_info])
+                self.frame_index = faiss.IndexFlatL2(features.shape[1])
+                self.frame_index.add(features)

+            self.logger.info(f"Processed {len(frames_info)} frames successfully")
+            return True

         except Exception as e:
+            self.logger.error(f"Video processing error: {str(e)}")
+            return False

     @torch.no_grad()
+    def search_frames(self, query: str, k: int = 4) -> List[Dict]:
+        """Search for relevant frames based on the query"""
         try:
+            # Process query
+            inputs = self.clip_processor(text=[query], return_tensors="pt").to(self.device)
+            if self.device.type == "cuda":
+                inputs = {k: v.half() if v.dtype == torch.float32 else v for k, v in inputs.items()}
+            query_features = self.clip_model.get_text_features(**inputs)
+
+            # Search
             distances, indices = self.frame_index.search(
+                query_features.cpu().numpy(),
                 k
             )
+
+            # Prepare results
             results = []
             for distance, idx in zip(distances[0], indices[0]):
                 frame_info = self.frame_data[idx].copy()
+                frame_info["relevance"] = float(1 / (1 + distance))
                 results.append(frame_info)
+
             return results
+
         except Exception as e:
+            self.logger.error(f"Search error: {str(e)}")
+            return []

+class VideoQAInterface:
     def __init__(self):
+        self.rag_system = VideoRAGSystem()
+        self.current_video = None
         self.processed = False

     def process_video(self, video_file):
+        """Handle video upload and processing"""
         try:
             if video_file is None:
                 return "Please upload a video first.", gr.Progress(0)

+            self.current_video = video_file.name
+            success = self.rag_system.process_video(self.current_video)

+            if success:
+                self.processed = True
+                return "Video processed successfully! You can now ask questions.", gr.Progress(100)
+            else:
+                return "Error processing video. Please try again.", gr.Progress(0)
+
         except Exception as e:
             self.processed = False
+            return f"Error: {str(e)}", gr.Progress(0)

+    def answer_question(self, query):
+        """Handle question answering"""
         if not self.processed:
             return None, "Please process a video first."
+
         try:
+            # Search for relevant frames
+            results = self.rag_system.search_frames(query)
+
+            if not results:
+                return None, "No relevant frames found."
+
+            # Prepare output
             frames = []
             descriptions = []
+
             for result in results:
+                # Load frame
+                frame = Image.open(result["path"])
+                frames.append(frame)

+                # Prepare description
+                desc = f"Timestamp: {result['timestamp']:.2f}s\n"
+                desc += f"Scene Description: {result['caption']}\n"
+                desc += f"Relevance Score: {result['relevance']:.2f}"
+                descriptions.append(desc)
+
+            # Combine descriptions
+            combined_desc = "\n\nFrame Analysis:\n\n"
             for i, desc in enumerate(descriptions, 1):
+                combined_desc += f"Frame {i}:\n{desc}\n\n"
+
+            return frames, combined_desc
+
         except Exception as e:
+            return None, f"Error: {str(e)}"

     def create_interface(self):
         """Create Gradio interface"""
+        with gr.Blocks(title="Advanced Video Question Answering") as interface:
             gr.Markdown("# Advanced Video Question Answering")
             gr.Markdown("Upload a video and ask questions about any aspect of its content!")

             with gr.Row():
                 video_input = gr.File(
+                    label="Upload Video",
                     file_types=["video"],
                 )
                 process_button = gr.Button("Process Video")

+            status_output = gr.Textbox(
+                label="Status",
+                interactive=False
+            )

             with gr.Row():
                 query_input = gr.Textbox(
+                    label="Ask about the video",
                     placeholder="What's happening in the video?"
                 )
                 query_button = gr.Button("Search")

                 gallery = gr.Gallery(
                     label="Retrieved Frames",
                     show_label=True,
                     columns=[2],
                     rows=[2],
                     height="auto"

             process_button.click(
                 fn=self.process_video,
                 inputs=[video_input],
+                outputs=[status_output]
             )

             query_button.click(
+                fn=self.answer_question,
                 inputs=[query_input],
                 outputs=[gallery, descriptions]
             )

         return interface

 # Initialize and create the interface
+app = VideoQAInterface()
 interface = app.create_interface()

 # Launch the app
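
For reference, a minimal headless sketch of how the VideoRAGSystem class introduced in this commit could be exercised outside the Gradio UI. This is not part of the commit: the video path "sample.mp4" and the query string are placeholders, and it assumes the class definition above is in scope (for example, run at the bottom of app.py before the launch call).

    from PIL import Image

    rag = VideoRAGSystem()                     # loads CLIP (ViT-B/32) and BLIP-2 once
    if rag.process_video("sample.mp4"):        # samples up to 15 keyframes, captions them, builds the FAISS index
        for hit in rag.search_frames("a person riding a bicycle", k=4):
            print(f"{hit['timestamp']:.2f}s  relevance={hit['relevance']:.2f}  {hit['caption']}")
            frame = Image.open(hit["path"])    # keyframe image saved to disk by extract_keyframes

Each result dict carries the frame's timestamp, BLIP-2 caption, saved image path, and a relevance score computed as 1 / (1 + L2 distance) between the CLIP text embedding of the query and the CLIP image embedding of the frame.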