aapot commited on
Commit
25bd6d3
1 Parent(s): 0df2f2d

Add gradio error handling for incorrect video urls

Browse files
Files changed (2) hide show
  1. app.py +15 -11
  2. utils/helper_funcs.py +5 -1
app.py CHANGED
@@ -40,17 +40,21 @@ def get_video_similarity(video1_url, video2_url):
40
  with_transcript = False
41
  else:
42
  with_transcript = True
43
- dataset = RRUMDataset(df, with_transcript=with_transcript, label_col=None,
44
- cross_encoder_model_name_or_path=cross_encoder_model_name_or_path)
45
- data_loader = DataLoader(dataset.test_dataset, shuffle=False,
46
- batch_size=1, num_workers=0, pin_memory=False)
47
-
48
- with torch.inference_mode():
49
- if with_transcript:
50
- pred = model_wt(next(iter(data_loader)))
51
- else:
52
- pred = model_nt(next(iter(data_loader)))
53
- pred = torch.special.expit(pred).squeeze().tolist()
 
 
 
 
54
  return f'YouTube videos are {pred:.0%} similar'
55
 
56
 
 
40
  with_transcript = False
41
  else:
42
  with_transcript = True
43
+ try:
44
+ dataset = RRUMDataset(df, with_transcript=with_transcript, label_col=None,
45
+ cross_encoder_model_name_or_path=cross_encoder_model_name_or_path)
46
+ data_loader = DataLoader(dataset.test_dataset, shuffle=False,
47
+ batch_size=1, num_workers=0, pin_memory=False)
48
+
49
+ with torch.inference_mode():
50
+ if with_transcript:
51
+ pred = model_wt(next(iter(data_loader)))
52
+ else:
53
+ pred = model_nt(next(iter(data_loader)))
54
+ pred = torch.special.expit(pred).squeeze().tolist()
55
+ except:
56
+ raise gr.Error(
57
+ f'There was error in getting a prediction from the model, please try again.')
58
  return f'YouTube videos are {pred:.0%} similar'
59
 
60
 
utils/helper_funcs.py CHANGED
@@ -2,6 +2,7 @@ import itertools
2
  import random
3
  import requests
4
  import pandas as pd
 
5
  from pytube import YouTube
6
  from youtube_transcript_api import YouTubeTranscriptApi
7
  from youtube_transcript_api.formatters import TextFormatter
@@ -62,7 +63,10 @@ def update_youtube_embedded_html(video_url, video_position):
62
 
63
 
64
  def get_youtube_video_data(url):
65
- video = YouTube(url)
 
 
 
66
  channel_id = video.channel_id
67
  video_title = video.title
68
  video_description = video.description
 
2
  import random
3
  import requests
4
  import pandas as pd
5
+ import gradio as gr
6
  from pytube import YouTube
7
  from youtube_transcript_api import YouTubeTranscriptApi
8
  from youtube_transcript_api.formatters import TextFormatter
 
63
 
64
 
65
  def get_youtube_video_data(url):
66
+ try:
67
+ video = YouTube(url)
68
+ except:
69
+ raise gr.Error(f'Could not find YouTube video with the URL {url}')
70
  channel_id = video.channel_id
71
  video_title = video.title
72
  video_description = video.description