# Gxhhfhdhd / apgfp.py
# Yhhxhfh's picture
# Rename app.py to apgfp.py
# a1f943d verified
# raw
# history blame
# 5.36 kB
from fastapi import FastAPI, HTTPException, UploadFile, File
from pydantic import BaseModel
from aitextgen import aitextgen
from sklearn.datasets import fetch_20newsgroups
import nltk
import spacy
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
from transformers import TTSModel, TTSProcessor
from audiocraft.models import MusicGen
from diffusers import StableDiffusionPipeline
import os
from typing import List
# Fetch the NLTK data used by /process/ and load the spaCy English pipeline once,
# at import time, so request handlers can reuse them.
nltk.download('punkt')
# NLTK 3.8.2+ makes word_tokenize() look up "punkt_tab" instead of the pickled
# "punkt" model, so /process/ fails with LookupError without it. With the default
# raise_on_error=False, older NLTK releases that lack this package only report
# the failure and return False instead of raising.
nltk.download('punkt_tab')
nltk.download('stopwords')
spacy_model = spacy.load('en_core_web_sm')
app = FastAPI()
# Module-level caches for lazily loaded models and data. Each one stays None
# until the matching load_*() helper in this module populates it on first use.
# A `global` statement at module scope is a no-op, so none is needed here.
aitextgen_model = None
hf_model = None
musicgen_model = None
image_generation_model = None
whisper_model = None
whisper_processor = None
tts_model = None
tts_processor = None
newsgroups = None
# Loader functions: each model is created only once and then reused.
def load_aitextgen_model():
    """Return the shared aitextgen instance, creating it on the first call.

    The instance is built with ``aitextgen()`` (no arguments) and cached in the
    module-level ``aitextgen_model`` global so later calls reuse it.
    """
    global aitextgen_model
    if aitextgen_model is not None:
        return aitextgen_model
    aitextgen_model = aitextgen()
    return aitextgen_model
def load_hf_model():
    """Return the cached Hugging Face GPT-2 text-generation pipeline.

    Built once via ``pipeline('text-generation', model='gpt2')`` and stored in
    the module-level ``hf_model`` global.
    """
    global hf_model
    if hf_model is not None:
        return hf_model
    hf_model = pipeline('text-generation', model='gpt2')
    return hf_model
def load_musicgen_model():
    """Return the cached MusicGen model (the ``'small'`` pretrained variant).

    ``MusicGen.get_pretrained('small')`` runs only on the first call; the
    result lives in the module-level ``musicgen_model`` global afterwards.
    """
    global musicgen_model
    if musicgen_model is not None:
        return musicgen_model
    musicgen_model = MusicGen.get_pretrained('small')
    return musicgen_model
def load_image_generation_model():
    """Return the cached Stable Diffusion pipeline (CompVis/stable-diffusion-v1-4).

    The pipeline is created with ``StableDiffusionPipeline.from_pretrained`` on
    first use and kept in the module-level ``image_generation_model`` global.
    """
    global image_generation_model
    if image_generation_model is not None:
        return image_generation_model
    image_generation_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    return image_generation_model
def load_whisper_model():
    """Return the cached ``(model, processor)`` pair for openai/whisper-small.

    Both objects are loaded on first use and stored in the module-level
    ``whisper_model`` / ``whisper_processor`` globals.

    Returns:
        tuple: ``(WhisperForConditionalGeneration, WhisperProcessor)``.
    """
    global whisper_model, whisper_processor
    # Check both halves and publish them together. Previously the model was
    # assigned before the processor was loaded and only the model was checked.
    # If the processor load failed, every later call returned (model, None)
    # instead of retrying.
    if whisper_model is None or whisper_processor is None:
        checkpoint = "openai/whisper-small"
        model = WhisperForConditionalGeneration.from_pretrained(checkpoint)
        processor = WhisperProcessor.from_pretrained(checkpoint)
        whisper_model, whisper_processor = model, processor
    return whisper_model, whisper_processor
def load_tts_model():
    """Return the cached ``(model, processor)`` pair for the TTS checkpoint.

    Both objects are loaded on first use and stored in the module-level
    ``tts_model`` / ``tts_processor`` globals.

    NOTE(review): ``TTSModel`` / ``TTSProcessor`` and the
    "facebook/tts_transformer-tts" checkpoint could not be confirmed as
    transformers APIs. Verify which TTS library is intended.
    """
    global tts_model, tts_processor
    # Check both halves and publish them together. Previously a failure while
    # loading the processor left tts_model set and tts_processor None forever,
    # because only tts_model was checked.
    if tts_model is None or tts_processor is None:
        checkpoint = "facebook/tts_transformer-tts"
        model = TTSModel.from_pretrained(checkpoint)
        processor = TTSProcessor.from_pretrained(checkpoint)
        tts_model, tts_processor = model, processor
    return tts_model, tts_processor
def load_newsgroups():
    """Return the document texts of the full 20 Newsgroups corpus.

    ``fetch_20newsgroups(subset='all')`` is called once. Its ``.data`` attribute
    is cached in the module-level ``newsgroups`` global for later calls.
    """
    global newsgroups
    if newsgroups is not None:
        return newsgroups
    dataset = fetch_20newsgroups(subset='all')
    newsgroups = dataset.data
    return newsgroups
class TextRequest(BaseModel):
    """Request body shared by the /generate/ and /hf_generate/ endpoints."""
    prompt: str  # text the generator continues from
    max_length: int = 50  # forwarded unchanged as the generator's max_length argument
class MusicRequest(BaseModel):
    """Request body for the /music/ endpoint."""
    prompt: str  # text description passed to MusicGen
    duration: float = 10.0  # requested clip length; presumably seconds -- confirm against audiocraft
class ImageRequest(BaseModel):
    """Request body for the /generate_image/ endpoint."""
    prompt: str  # text prompt for the Stable Diffusion pipeline
    # Output size forwarded to the pipeline as height/width.
    # NOTE(review): no validation here; the diffusers pipeline is expected to
    # reject sizes that are not multiples of 8 -- confirm.
    height: int = 512
    width: int = 512
class TTSRequest(BaseModel):
    """Request body for the /tts/ endpoint."""
    text: str  # text to synthesize into speech
@app.get("/")
def read_root():
    """Landing endpoint: return a fixed welcome message listing the API's features."""
    welcome = "Welcome to the Text, Music Generation, Image Generation, Whisper, and TTS API!"
    return dict(message=welcome)
@app.post("/generate/")
def generate_text(request: TextRequest):
    """Generate a continuation of ``request.prompt`` with aitextgen.

    Returns:
        dict: ``{"generated_text": <str>}``, holding a single generated sample.
    """
    model = load_aitextgen_model()
    # aitextgen.generate() prints its samples and returns None unless
    # return_as_list=True, so this endpoint always answered
    # {"generated_text": null}. generate_one() produces exactly one sample and
    # returns it as a string.
    generated_text = model.generate_one(prompt=request.prompt, max_length=request.max_length)
    return {"generated_text": generated_text}
@app.post("/hf_generate/")
def hf_generate_text(request: TextRequest):
    """Run the cached GPT-2 pipeline on the prompt and return the first sample's text."""
    generator = load_hf_model()
    outputs = generator(request.prompt, max_length=request.max_length)
    first_sample = outputs[0]
    return {"generated_text": first_sample['generated_text']}
@app.post("/music/")
def generate_music(request: MusicRequest):
    """Generate a music clip from a text prompt and write it to ``generated_music.wav``.

    Returns:
        dict: a success message and the name of the written file.
    """
    # audio_write comes from the same audiocraft package already used for MusicGen.
    from audiocraft.data.audio import audio_write

    model = load_musicgen_model()
    # MusicGen.generate() takes no ``durations`` argument. Clip length is a
    # generation parameter set on the model beforehand.
    # NOTE(review): this mutates the shared cached model, so concurrent
    # requests with different durations can interfere.
    model.set_generation_params(duration=request.duration)
    wav = model.generate([request.prompt])  # one clip per description
    # MusicGen has no save_wav(). audio_write appends the ".wav" suffix itself,
    # so the stem "generated_music" yields "generated_music.wav".
    audio_write('generated_music', wav[0].cpu(), model.sample_rate, strategy="loudness")
    return {"message": "Music generated successfully", "audio_file": "generated_music.wav"}
@app.post("/generate_image/")
def generate_image(request: ImageRequest):
    """Generate an image with Stable Diffusion and save it to ``generated_image.png``.

    Returns:
        dict: a success message and the name of the written file.

    Raises:
        HTTPException: 400 if height/width are not positive multiples of 8.
    """
    # The diffusers pipeline raises ValueError for sizes that are not multiples
    # of 8, which surfaced as an opaque 500. Report it as a client error instead,
    # before spending time loading the model.
    if request.height <= 0 or request.width <= 0 or request.height % 8 or request.width % 8:
        raise HTTPException(status_code=400, detail="height and width must be positive multiples of 8")
    pipe = load_image_generation_model()
    image = pipe(request.prompt, height=request.height, width=request.width).images[0]
    image_path = "generated_image.png"
    image.save(image_path)
    return {"message": "Image generated successfully", "image_file": image_path}
@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with Whisper.

    The upload is decoded (via ffmpeg) to a waveform at the processor's
    sampling rate, converted to input features, and decoded with the model.

    Returns:
        dict: ``{"transcription": <str>}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as audio.
    """
    # torch was used below but never imported anywhere in the module.
    import torch
    from transformers.pipelines.audio_utils import ffmpeg_read

    whisper_model, whisper_processor = load_whisper_model()
    payload = await file.read()
    sampling_rate = whisper_processor.feature_extractor.sampling_rate
    # WhisperProcessor expects a decoded waveform, not the raw bytes of the file.
    try:
        waveform = ffmpeg_read(payload, sampling_rate)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=f"Could not decode the uploaded audio: {err}") from err
    input_features = whisper_processor(
        waveform, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features
    # NOTE(review): inference runs inside an async handler and blocks the event
    # loop while it runs; consider offloading it to a thread pool.
    with torch.no_grad():
        predicted_ids = whisper_model.generate(input_features)
    transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    return {"transcription": transcription}
@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize ``request.text`` to speech and write it to ``generated_speech.wav``.

    Returns:
        dict: a success message and the name of the written file.

    NOTE(review): relies on ``TTSModel.generate`` and ``TTSModel.save_wav``.
    TTSModel/TTSProcessor do not appear to be transformers classes, and the
    module-level import of them may fail at startup. Confirm the intended TTS
    library before relying on this endpoint.
    """
    tts_model, tts_processor = load_tts_model()  # the processor is not used below
    audio = tts_model.generate(request.text)
    audio_path = "generated_speech.wav"
    tts_model.save_wav(audio, audio_path)
    return {"message": "Speech generated successfully", "audio_file": "generated_speech.wav"}
@app.get("/newsgroups/")
def get_newsgroups():
    """Return a small preview: the first five documents of the 20 Newsgroups corpus."""
    preview = load_newsgroups()[:5]
    return {"newsgroups": preview}
@app.post("/process/")
def process_text(text: str):
    """Tokenize ``text`` with NLTK and extract named entities with spaCy.

    ``text`` is a plain ``str`` parameter, so FastAPI reads it from the query
    string rather than from the request body.

    Returns:
        dict: ``tokens`` (NLTK word tokens) and ``entities`` as
        ``(text, label)`` pairs from spaCy.
    """
    tokens = nltk.word_tokenize(text)
    entities = [(span.text, span.label_) for span in spacy_model(text).ents]
    return {"tokens": tokens, "entities": entities}