File size: 5,361 Bytes
6d8726f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import os
from typing import List

import nltk
import spacy
import torch
from aitextgen import aitextgen
from audiocraft.models import MusicGen
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI, HTTPException, UploadFile, File
from pydantic import BaseModel
from sklearn.datasets import fetch_20newsgroups
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
# NOTE(review): transformers does not export TTSModel/TTSProcessor; this
# import will fail at startup — replace with a supported TTS API.
from transformers import TTSModel, TTSProcessor

# Download the NLTK resources and load the spaCy English pipeline at startup.
for _resource in ('punkt', 'stopwords'):
    nltk.download(_resource)
spacy_model = spacy.load('en_core_web_sm')

app = FastAPI()

# Lazily populated module-level caches, one per model; the load_* helpers
# below fill them on first use.  The original code had a module-level
# `global` statement here, which is a no-op at module scope (names assigned
# at module level are already global) — it has been removed.
aitextgen_model = None
hf_model = None
musicgen_model = None
image_generation_model = None
whisper_model = None
whisper_processor = None
tts_model = None
tts_processor = None
newsgroups = None

# Helper functions that load each model only once (lazy singletons)
def load_aitextgen_model():
    """Return the process-wide aitextgen model, creating it on first call."""
    global aitextgen_model
    if aitextgen_model is not None:
        return aitextgen_model
    aitextgen_model = aitextgen()
    return aitextgen_model

def load_hf_model():
    """Return the shared GPT-2 text-generation pipeline, building it once."""
    global hf_model
    if hf_model is not None:
        return hf_model
    hf_model = pipeline('text-generation', model='gpt2')
    return hf_model

def load_musicgen_model():
    """Return the shared MusicGen 'small' model, downloading it on first use."""
    global musicgen_model
    if musicgen_model is not None:
        return musicgen_model
    musicgen_model = MusicGen.get_pretrained('small')
    return musicgen_model

def load_image_generation_model():
    """Return the shared Stable Diffusion pipeline, loading it on first use."""
    global image_generation_model
    if image_generation_model is not None:
        return image_generation_model
    image_generation_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    return image_generation_model

def load_whisper_model():
    """Return the cached Whisper (model, processor) pair, loading both once."""
    global whisper_model, whisper_processor
    if whisper_model is None:
        checkpoint = "openai/whisper-small"
        whisper_model = WhisperForConditionalGeneration.from_pretrained(checkpoint)
        whisper_processor = WhisperProcessor.from_pretrained(checkpoint)
    return whisper_model, whisper_processor

def load_tts_model():
    """Lazily load and cache the TTS model and processor pair.

    NOTE(review): `transformers` does not export `TTSModel`/`TTSProcessor`,
    so the top-of-file import — and therefore this function — cannot work as
    written.  Replace with a supported text-to-speech API (e.g. a
    transformers "text-to-speech" pipeline) and verify the checkpoint name.
    """
    global tts_model, tts_processor
    if tts_model is None:
        tts_model = TTSModel.from_pretrained("facebook/tts_transformer-tts")
        tts_processor = TTSProcessor.from_pretrained("facebook/tts_transformer-tts")
    return tts_model, tts_processor

def load_newsgroups():
    """Return the cached 20-newsgroups document list, fetching it on first use."""
    global newsgroups
    if newsgroups is not None:
        return newsgroups
    newsgroups = fetch_20newsgroups(subset='all').data
    return newsgroups

class TextRequest(BaseModel):
    """Request body for the text-generation endpoints (/generate/, /hf_generate/)."""
    prompt: str
    max_length: int = 50  # upper bound on generated sequence length (tokens)

class MusicRequest(BaseModel):
    """Request body for the /music/ endpoint."""
    prompt: str
    duration: float = 10.0  # requested clip length in seconds

class ImageRequest(BaseModel):
    """Request body for the /generate_image/ endpoint."""
    prompt: str
    height: int = 512  # output image height in pixels
    width: int = 512   # output image width in pixels

class TTSRequest(BaseModel):
    """Request body for the /tts/ endpoint."""
    text: str

@app.get("/")
def read_root():
    """Landing endpoint: returns a static welcome message."""
    welcome = "Welcome to the Text, Music Generation, Image Generation, Whisper, and TTS API!"
    return {"message": welcome}

@app.post("/generate/")
def generate_text(request: TextRequest):
    """Generate text from a prompt with the aitextgen model.

    Fix: aitextgen's ``generate()`` prints to stdout and returns ``None``
    unless ``return_as_list=True`` is passed, so the original endpoint
    always responded with ``{"generated_text": None}``.
    """
    model = load_aitextgen_model()
    texts = model.generate(
        prompt=request.prompt, max_length=request.max_length, return_as_list=True
    )
    return {"generated_text": texts[0]}

@app.post("/hf_generate/")
def hf_generate_text(request: TextRequest):
    """Generate text from a prompt using the Hugging Face GPT-2 pipeline."""
    pipe = load_hf_model()
    outputs = pipe(request.prompt, max_length=request.max_length)
    first = outputs[0]
    return {"generated_text": first['generated_text']}

@app.post("/music/")
def generate_music(request: MusicRequest):
    """Generate a music clip from a text prompt with MusicGen.

    Fixes: ``MusicGen.generate()`` takes no ``durations=`` keyword — the
    clip length is set via ``set_generation_params(duration=...)`` — and
    ``MusicGen`` has no ``save_wav`` method; audiocraft's ``audio_write``
    helper is the documented way to save the output tensor.
    """
    # Local import is safe: the file already depends on audiocraft.
    from audiocraft.data.audio import audio_write

    model = load_musicgen_model()
    model.set_generation_params(duration=request.duration)
    # generate() returns a (batch, channels, samples) tensor; take item 0.
    audio = model.generate([request.prompt])
    # audio_write appends the ".wav" suffix itself.
    audio_write("generated_music", audio[0].cpu(), model.sample_rate, strategy="loudness")
    return {"message": "Music generated successfully", "audio_file": "generated_music.wav"}

@app.post("/generate_image/")
def generate_image(request: ImageRequest):
    """Render an image from a text prompt with Stable Diffusion and save it."""
    pipe = load_image_generation_model()
    result = pipe(request.prompt, height=request.height, width=request.width)
    result.images[0].save("generated_image.png")
    return {"message": "Image generated successfully", "image_file": "generated_image.png"}

@app.post("/transcribe/")
async def transcribe_audio(file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with Whisper.

    Fix: the original used ``torch.no_grad()`` without importing torch,
    raising ``NameError`` on every request (import added at module level).

    NOTE(review): the processor is fed the raw upload bytes; Whisper's
    processor normally expects a decoded float waveform plus a 16 kHz
    ``sampling_rate`` — confirm the payload format or add audio decoding.
    """
    whisper_model, whisper_processor = load_whisper_model()
    audio_input = await file.read()
    input_features = whisper_processor(audio_input, return_tensors="pt").input_features

    # Inference only — disable autograd bookkeeping.
    with torch.no_grad():
        predicted_ids = whisper_model.generate(input_features)
    transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    return {"transcription": transcription}

@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize speech from text and save it to generated_speech.wav.

    NOTE(review): depends on load_tts_model(), whose TTSModel/TTSProcessor
    classes do not exist in `transformers`; the `.generate()`/`.save_wav()`
    calls below are therefore unverified — confirm against whichever TTS
    API ultimately replaces them.
    """
    tts_model, tts_processor = load_tts_model()
    audio = tts_model.generate(request.text)
    audio_path = "generated_speech.wav"
    tts_model.save_wav(audio, audio_path)
    return {"message": "Speech generated successfully", "audio_file": "generated_speech.wav"}

@app.get("/newsgroups/")
def get_newsgroups():
    """Return a small sample (first five documents) of the 20-newsgroups corpus."""
    docs = load_newsgroups()
    sample = docs[:5]
    return {"newsgroups": sample}

@app.post("/process/")
def process_text(text: str):
    """Tokenize the input with NLTK and extract named entities with spaCy."""
    doc = spacy_model(text)
    entities = [(ent.text, ent.label_) for ent in doc.ents]
    return {
        "tokens": nltk.word_tokenize(text),
        "entities": entities,
    }