"""Turn the English transcript of a YouTube video into a bulleted summary.

``pipeline`` wires the steps together: fetch the transcript, restore its
punctuation, translate it (Vietnamese by default), lower-case it, split it
into overlapping chunks, summarise every chunk with a seq2seq model and
format the result as bullet points.
"""
import re
import time
from urllib.parse import parse_qs, urlparse

import torch
from datasets import Dataset
from deepmultilingualpunctuation import PunctuationModel
from googletrans import Translator
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    TrainingArguments,
)
from youtube_transcript_api import YouTubeTranscriptApi

# NOTE(review): disabled patch kept from the original file; presumably a
# googletrans/httpcore compatibility workaround - confirm before re-enabling.
# import httpcore
# setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')

# Checkpoint name for the summarisation model (see ``load_model``).
cp_aug = 'minnehwg/finetune-newwiki-summarization-ver-augmented2'

# Sentence boundary used by the two reconstructed chunking functions below:
# whitespace that follows '.', '!' or '?'.
_SENTENCE_BOUNDARY = re.compile(r'(?<=[.!?])\s+')


def load_model(cp):
    """Load the tokenizer and the seq2seq model for checkpoint ``cp``.

    The tokenizer always comes from ``"VietAI/vit5-base"``; only the model
    weights are taken from ``cp``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained("VietAI/vit5-base")
    model = AutoModelForSeq2SeqLM.from_pretrained(cp)
    return tokenizer, model


def summarize(text, model, tokenizer, num_beams=4, device='cpu'):
    """Generate a summary of ``text`` with ``model``.

    Args:
        text: Text to summarise; it is truncated to 1024 tokens when encoded.
        model: Seq2seq model; it is moved to ``device`` on every call.
        tokenizer: Tokenizer used for both encoding and decoding.
        num_beams: Beam width passed to ``model.generate``.
        device: Torch device the model and the inputs are placed on.

    Returns:
        The first generated sequence decoded without special tokens
        (generation is capped at ``max_length=256``).
    """
    model.to(device)
    inputs = tokenizer.encode(
        text,
        return_tensors="pt",
        max_length=1024,
        truncation=True,
        padding=True,
    ).to(device)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        summary_ids = model.generate(inputs, max_length=256, num_beams=num_beams)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def processed(text):
    """Return ``text`` with newlines replaced by spaces and lower-cased."""
    return text.replace('\n', ' ').lower()


def _extract_video_id(video_url):
    """Return the video id found in ``video_url``.

    Supports the ``v`` query parameter (extra parameters such as ``&t=`` or
    ``&list=`` are ignored) and ``youtu.be/<id>`` short links.

    Raises:
        ValueError: if no video id can be found.
    """
    parsed = urlparse(video_url)
    query_ids = parse_qs(parsed.query).get("v")
    if query_ids:
        return query_ids[0]
    short_id = parsed.path.strip("/").split("/")[0]
    if parsed.netloc.endswith("youtu.be") and short_id:
        return short_id
    if "v=" in video_url:
        # Fallback for strings urlparse does not see as a query (e.g. "v=ID"),
        # which the previous ``split("v=")[1]`` implementation accepted.
        return video_url.split("v=")[1].split("&")[0]
    raise ValueError(f"no video id found in {video_url!r}")


def get_subtitles(video_url):
    """Fetch the English transcript of the video behind ``video_url``.

    Returns:
        ``(transcript, subs)`` where ``transcript`` is what
        ``YouTubeTranscriptApi.get_transcript`` returned and ``subs`` is the
        ``'text'`` of every entry joined with spaces. On any failure the
        result is ``([], "An error occurred: <reason>")``; no exception
        propagates to the caller.
    """
    try:
        video_id = _extract_video_id(video_url)
        transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
        subs = " ".join(entry['text'] for entry in transcript)
        return transcript, subs
    except Exception as e:
        # Deliberately broad: callers rely on the ([], message) error contract.
        return [], f"An error occurred: {e}"


def restore_punctuation(text):
    """Return ``text`` with punctuation restored by ``PunctuationModel``.

    A new ``PunctuationModel`` is instantiated on every call.
    """
    model = PunctuationModel()
    return model.restore_punctuation(text)


def translate_long(text, language='vi'):
    """Translate ``text`` into ``language`` chunk by chunk.

    The text is cut at sentence boundaries and packed into chunks of at most
    ``limit`` (4700) characters; each chunk is translated separately and the
    translations are joined with single spaces.

    NOTE(review): only the opening of this function survived in the source
    (``translator``, ``limit``, ``chunks``, ``current_chunk`` and the start of
    the sentence split). The packing loop and the translation loop below are
    a reconstruction - confirm against the original. The file imports
    ``time`` but nothing visible uses it, so the lost code may have paused
    between requests.

    Returns:
        The translated text, or ``''`` when ``text`` holds no sentences.
    """
    translator = Translator()
    # Maximum chunk size in characters; presumably chosen to stay below the
    # translation service's request-size cap - TODO confirm.
    limit = 4700

    chunks = []
    current_chunk = ''
    for sentence in _SENTENCE_BOUNDARY.split(text):
        sentence = sentence.strip()
        if not sentence:
            continue
        # A single sentence longer than the limit is hard-wrapped so that no
        # request exceeds ``limit`` characters.
        while len(sentence) > limit:
            if current_chunk:
                chunks.append(current_chunk)
                current_chunk = ''
            chunks.append(sentence[:limit])
            sentence = sentence[limit:]
        if current_chunk and len(current_chunk) + 1 + len(sentence) > limit:
            chunks.append(current_chunk)
            current_chunk = sentence
        elif current_chunk:
            current_chunk = f"{current_chunk} {sentence}"
        else:
            current_chunk = sentence
    if current_chunk:
        chunks.append(current_chunk)

    translated = [translator.translate(chunk, dest=language).text for chunk in chunks]
    return ' '.join(translated)


def split_into_chunks(text, max_words=700, overlap_sentences=3):
    """Split ``text`` into chunks of whole sentences that overlap.

    Sentences are accumulated until adding the next one would exceed
    ``max_words`` words; the chunk is then closed and the next chunk starts
    with the last ``overlap_sentences`` sentences of the closed one.

    NOTE(review): the signature and the first half of the loop were lost in
    the source; everything from the overlap handling onwards is original.
    The parameter name ``max_words`` and both defaults (taken from the only
    visible call, ``split_into_chunks(vie_sub, 700, 3)``) are assumptions -
    confirm against the original.

    Args:
        text: Text to split.
        max_words: Word budget that triggers closing the current chunk.
        overlap_sentences: Number of trailing sentences repeated at the start
            of the next chunk; ``0`` disables the overlap.

    Returns:
        A list of chunk strings (sentences joined with single spaces); empty
        when ``text`` holds no sentences.
    """
    # Blank fragments are dropped so that empty input yields no chunk at all.
    sentences = [s for s in _SENTENCE_BOUNDARY.split(text) if s.strip()]

    chunks = []
    current_chunk = []
    current_word_count = 0
    for sentence in sentences:
        word_count = len(sentence.split())
        if current_word_count + word_count <= max_words:
            current_chunk.append(sentence)
            current_word_count += word_count
            continue

        # ``lst[-0:]`` is the whole list, so a zero overlap needs its own case.
        overlap = current_chunk[-overlap_sentences:] if overlap_sentences > 0 else []
        if overlap and len(current_chunk) >= overlap_sentences:
            print(f"Overlapping sentences: {' '.join(overlap)}")
        # An oversize first sentence would otherwise emit an empty chunk.
        if current_chunk:
            chunks.append(' '.join(current_chunk))
        current_chunk = overlap + [sentence]
        current_word_count = sum(len(sent.split()) for sent in current_chunk)

    if current_chunk:
        if overlap_sentences > 0 and len(current_chunk) >= overlap_sentences:
            overlap = current_chunk[-overlap_sentences:]
            print(f"Overlapping sentences: {' '.join(overlap)}")
        chunks.append(' '.join(current_chunk))
    return chunks


def post_processing(text):
    """Upper-case the first character of every sentence in ``text``.

    The text is split after each '.', '!' or '?' (any following whitespace is
    consumed) and the pieces are re-joined with single spaces, so text that
    ends with punctuation comes back with one trailing space.
    """
    sentences = re.split(r'(?<=[.!?])\s*', text)
    capitalised = [s[0].upper() + s[1:] if s else s for s in sentences]
    return " ".join(capitalised)


def display(text):
    """Format ``text`` as a list of unique bullet-point sentences.

    The text is split after each '.', '!' or '?'; blank fragments are
    dropped, duplicates are removed keeping the first occurrence, and every
    remaining sentence is prefixed with "• ". A final sentence without
    closing punctuation is kept.
    """
    sentences = [s for s in re.split(r'(?<=[.!?])\s*', text) if s.strip()]
    unique_sentences = list(dict.fromkeys(sentences))
    return [f"• {sentence}" for sentence in unique_sentences]


def pipeline(url):
    """Summarise the YouTube video at ``url`` into bullet points.

    NOTE(review): ``model_aug`` and ``tokenizer`` are not defined anywhere in
    the visible file; they presumably have to exist as module globals (for
    example ``tokenizer, model_aug = load_model(cp_aug)``) before this
    function is called - confirm.

    Returns:
        A list of "• ..." strings. When no transcript could be fetched, the
        list holds only the message produced by ``get_subtitles`` (or is
        empty if there is no message) instead of a summary of that message.
    """
    transcript, subtitles = get_subtitles(url)
    if not transcript:
        # get_subtitles reports failure as ([], message); do not feed the
        # message through translation and summarisation.
        return [subtitles] if subtitles else []

    subtitles = restore_punctuation(subtitles)
    vie_sub = processed(translate_long(subtitles))
    chunks = split_into_chunks(vie_sub, 700, 3)
    summaries = [
        summarize(chunk, model_aug, tokenizer, num_beams=4) for chunk in chunks
    ]
    # Join with a space so a summary lacking end punctuation does not fuse
    # with the first word of the next one.
    summary_text = post_processing(' '.join(summaries))
    return display(summary_text)