Yhhxhfh committed on
Commit
6d8726f
1 Parent(s): 0ac46ef

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, UploadFile, File
2
+ from pydantic import BaseModel
3
+ from aitextgen import aitextgen
4
+ from sklearn.datasets import fetch_20newsgroups
5
+ import nltk
6
+ import spacy
7
+ from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
8
+ from transformers import TTSModel, TTSProcessor
9
+ from audiocraft.models import MusicGen
10
+ from diffusers import StableDiffusionPipeline
11
+ import os
12
+ from typing import List
13
+
14
# Download the NLTK resources and load the spaCy pipeline once, at import time.
nltk.download('punkt')
nltk.download('stopwords')
spacy_model = spacy.load('en_core_web_sm')

app = FastAPI()

# Module-level caches for the lazily loaded models and data. Each one stays
# None until the matching load_* function is called for the first time.
# (No `global` statement is needed here: names assigned at module level are
# already module globals.)
aitextgen_model = None
hf_model = None
musicgen_model = None
image_generation_model = None
whisper_model = None
whisper_processor = None
tts_model = None
tts_processor = None
newsgroups = None
32
+
33
# Loader functions: each one builds its model only once and then reuses it.
def load_aitextgen_model():
    """Return the shared aitextgen model, constructing it on the first call."""
    global aitextgen_model
    if aitextgen_model is not None:
        return aitextgen_model
    aitextgen_model = aitextgen()
    return aitextgen_model
39
+
40
def load_hf_model():
    """Return the shared GPT-2 text-generation pipeline, building it on first use."""
    global hf_model
    if hf_model is not None:
        return hf_model
    hf_model = pipeline('text-generation', model='gpt2')
    return hf_model
45
+
46
def load_musicgen_model():
    """Return the shared MusicGen model, fetching the 'small' checkpoint on first use."""
    global musicgen_model
    if musicgen_model is not None:
        return musicgen_model
    musicgen_model = MusicGen.get_pretrained('small')
    return musicgen_model
51
+
52
def load_image_generation_model():
    """Return the shared Stable Diffusion pipeline, loading it on the first call."""
    global image_generation_model
    if image_generation_model is not None:
        return image_generation_model
    image_generation_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    return image_generation_model
57
+
58
def load_whisper_model():
    """Return the shared Whisper model and processor, loading them on first use.

    Both objects are loaded into locals and only then published to the module
    globals. If loading the processor fails, the cache is therefore left
    untouched instead of holding a model with a ``None`` processor that later
    calls would hand back unchanged.

    Returns:
        tuple: ``(whisper_model, whisper_processor)``.
    """
    global whisper_model, whisper_processor
    if whisper_model is None or whisper_processor is None:
        model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
        processor = WhisperProcessor.from_pretrained("openai/whisper-small")
        # Publish both together so the cache is never half-initialised.
        whisper_model, whisper_processor = model, processor
    return whisper_model, whisper_processor
64
+
65
def load_tts_model():
    """Return the shared TTS model and processor, loading them on first use.

    Both objects are loaded into locals and only then published to the module
    globals, so a failure while loading the processor cannot leave the cache
    with a model but a ``None`` processor.

    NOTE(review): ``TTSModel`` / ``TTSProcessor`` and the checkpoint id
    "facebook/tts_transformer-tts" are taken as-is from this file; confirm
    they exist in the installed ``transformers`` release and on the model hub.

    Returns:
        tuple: ``(tts_model, tts_processor)``.
    """
    global tts_model, tts_processor
    if tts_model is None or tts_processor is None:
        model = TTSModel.from_pretrained("facebook/tts_transformer-tts")
        processor = TTSProcessor.from_pretrained("facebook/tts_transformer-tts")
        # Publish both together so the cache is never half-initialised.
        tts_model, tts_processor = model, processor
    return tts_model, tts_processor
71
+
72
def load_newsgroups():
    """Return the cached 20 Newsgroups documents, fetching them on the first call."""
    global newsgroups
    if newsgroups is not None:
        return newsgroups
    dataset = fetch_20newsgroups(subset='all')
    newsgroups = dataset.data
    return newsgroups
77
+
78
class TextRequest(BaseModel):
    """Request body shared by the two text-generation endpoints."""

    # Text the generated output should continue from.
    prompt: str
    # Forwarded to the generator as its max_length argument.
    max_length: int = 50
81
+
82
class MusicRequest(BaseModel):
    """Request body for the music-generation endpoint."""

    # Text description of the music to generate.
    prompt: str
    # Requested clip duration (presumably seconds; confirm against MusicGen).
    duration: float = 10.0
85
+
86
class ImageRequest(BaseModel):
    """Request body for the image-generation endpoint."""

    # Text description of the image to generate.
    prompt: str
    # Output dimensions passed straight through to the diffusion pipeline.
    height: int = 512
    width: int = 512
90
+
91
class TTSRequest(BaseModel):
    """Request body for the text-to-speech endpoint."""

    # Text to be spoken.
    text: str
93
+
94
@app.get("/")
def read_root():
    """Return the static welcome payload served at the API root."""
    welcome = "Welcome to the Text, Music Generation, Image Generation, Whisper, and TTS API!"
    return {"message": welcome}
97
+
98
@app.post("/generate/")
def generate_text(request: TextRequest):
    """Generate one text continuation with the aitextgen model.

    Args:
        request: Prompt and maximum length for the generation.

    Returns:
        dict: ``{"generated_text": <str>}`` holding the single generated sample.
    """
    model = load_aitextgen_model()
    # aitextgen's generate() prints to stdout and returns None by default;
    # return_as_list=True makes it hand the generated strings back instead.
    samples = model.generate(
        n=1,
        prompt=request.prompt,
        max_length=request.max_length,
        return_as_list=True,
    )
    return {"generated_text": samples[0]}
103
+
104
@app.post("/hf_generate/")
def hf_generate_text(request: TextRequest):
    """Generate text with the Hugging Face GPT-2 pipeline and return the first result."""
    generator = load_hf_model()
    outputs = generator(request.prompt, max_length=request.max_length)
    first_output = outputs[0]
    return {"generated_text": first_output['generated_text']}
109
+
110
@app.post("/music/")
def generate_music(request: MusicRequest):
    """Generate a music clip from a text prompt and save it as ``generated_music.wav``.

    Args:
        request: Text description and requested duration of the clip.

    Returns:
        dict: A success message and the name of the written audio file.
    """
    # Local import: this helper is not imported at module level in this file.
    from audiocraft.data.audio import audio_write

    model = load_musicgen_model()
    # MusicGen takes the clip length through set_generation_params();
    # generate() itself has no per-call duration argument.
    model.set_generation_params(duration=request.duration)
    waveforms = model.generate([request.prompt])
    # audio_write() appends the ".wav" suffix to the stem it is given.
    audio_write('generated_music', waveforms[0].cpu(), model.sample_rate, strategy="loudness")
    return {"message": "Music generated successfully", "audio_file": "generated_music.wav"}
116
+
117
@app.post("/generate_image/")
def generate_image(request: ImageRequest):
    """Render an image for the prompt and save it as ``generated_image.png``."""
    output_path = "generated_image.png"
    pipe = load_image_generation_model()
    result = pipe(request.prompt, height=request.height, width=request.width)
    # Only the first generated image is kept.
    result.images[0].save(output_path)
    return {"message": "Image generated successfully", "image_file": output_path}
124
+
125
@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded WAV file with the Whisper model.

    The upload is decoded from 16-bit PCM WAV, mixed down to mono and
    resampled to 16 kHz before being handed to the Whisper processor, which
    works on a decoded waveform rather than on the raw file bytes.

    Args:
        file: Uploaded audio file (16-bit PCM WAV).

    Returns:
        dict: ``{"transcription": <str>}``.

    Raises:
        HTTPException: 400 if the upload is not a readable 16-bit PCM WAV file
            or contains no audio samples.
    """
    # Local imports: none of these names are imported at module level here.
    import io
    import math
    import wave

    import numpy as np
    import torch
    from scipy.signal import resample_poly

    target_rate = 16000  # sampling rate passed to the Whisper feature extractor

    model, processor = load_whisper_model()
    payload = await file.read()

    try:
        with wave.open(io.BytesIO(payload), "rb") as wav_file:
            channels = wav_file.getnchannels()
            sample_width = wav_file.getsampwidth()
            source_rate = wav_file.getframerate()
            frames = wav_file.readframes(wav_file.getnframes())
    except (wave.Error, EOFError) as exc:
        raise HTTPException(status_code=400, detail="Upload must be a PCM WAV file") from exc

    if sample_width != 2:
        raise HTTPException(status_code=400, detail="Only 16-bit PCM WAV audio is supported")

    # Scale signed 16-bit samples to float32 in [-1.0, 1.0).
    waveform = np.frombuffer(frames, dtype="<i2").astype(np.float32) / 32768.0
    if waveform.size == 0:
        raise HTTPException(status_code=400, detail="Uploaded audio contains no samples")
    if channels > 1:
        # Samples are interleaved per frame; average the channels down to mono.
        waveform = waveform.reshape(-1, channels).mean(axis=1)
    if source_rate != target_rate:
        divisor = math.gcd(source_rate, target_rate)
        waveform = resample_poly(
            waveform, target_rate // divisor, source_rate // divisor
        ).astype(np.float32)

    features = processor(
        waveform, sampling_rate=target_rate, return_tensors="pt"
    ).input_features

    with torch.no_grad():
        predicted_ids = model.generate(features)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    return {"transcription": transcription}
136
+
137
@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesise speech for ``request.text`` and save it as ``generated_speech.wav``.

    NOTE(review): this relies on the loaded TTS model exposing ``generate``
    and ``save_wav``; confirm against the class that load_tts_model actually
    returns. The processor returned by the loader is not used here.
    """
    output_path = "generated_speech.wav"
    model, _processor = load_tts_model()
    waveform = model.generate(request.text)
    model.save_wav(waveform, output_path)
    return {"message": "Speech generated successfully", "audio_file": output_path}
144
+
145
@app.get("/newsgroups/")
def get_newsgroups():
    """Return the first five documents of the cached 20 Newsgroups data."""
    first_posts = load_newsgroups()[:5]
    return {"newsgroups": first_posts}
149
+
150
@app.post("/process/")
def process_text(text: str):
    """Tokenise ``text`` with NLTK and list the named entities spaCy finds in it.

    Returns:
        dict: ``tokens`` (the NLTK word tokens) and ``entities`` (one
        ``(text, label)`` pair per spaCy entity).
    """
    word_tokens = nltk.word_tokenize(text)
    entities = []
    for entity in spacy_model(text).ents:
        entities.append((entity.text, entity.label_))
    return {"tokens": word_tokens, "entities": entities}