# wanderJoy / app.py
import gradio as gr
from PIL import Image
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC, ViTFeatureExtractor, ViTForImageClassification
import soundfile as sf
import torch
import numpy as np
import time
# Initialize the transformers and the models
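# Human-readable names for the monument classifier's class indices.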
class_names = {
0: "al qarawiyyin",
1: "bab mansour el aleuj",
2: "chaouara tannery",
3: "hassan tower",
4: "jamae el fna",
5: "koutoubia mosque",
6: "madrasa ben youssef",
7: "majorel gardens",
8: "menara"
}
model_name_or_path = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wav2vec2_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
vit_model = ViTForImageClassification.from_pretrained('ohidaoui/monuments-morocco-v1')
vit_feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
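# NOTE: the handlers below call chat(...), which is never defined in this file.
# A minimal sketch of the missing helper, assuming the DialoGPT model loaded
# above is meant to answer each question in a single turn.
def chat(payload):
    question = payload["question"]
    input_ids = tokenizer(question + tokenizer.eos_token, return_tensors="pt").input_ids
    output_ids = model.generate(
        input_ids,
        max_new_tokens=200,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the tokens generated after the prompt.
    answer = tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)
    return {"answer": answer}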
# Function to handle text input
def handle_text(text):
chat_output = chat({"question": text})
return chat_output["answer"]
# Function to handle image input
def get_class_name(class_idx):
return class_names[class_idx]
def handle_image(img):
    img = np.array(img)
    inputs = vit_feature_extractor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = vit_model(**inputs)
    predicted_class_idx = torch.argmax(outputs.logits, dim=1).item()
    predicted_class_name = get_class_name(predicted_class_idx)
    # Ask the chat model about the recognised monument.
    chat_output = chat({"question": "what is " + predicted_class_name})
    return chat_output["answer"]
# Function to handle audio input
def handle_audio(audio):
    # gr.Audio delivers a (sample_rate, data) tuple of raw samples.
    sample_rate, data = audio
    # wav2vec2-base-960h expects mono float32 audio; this assumes the recording
    # is already at 16 kHz (resample first if the browser records at a higher rate).
    if np.issubdtype(data.dtype, np.integer):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    if data.ndim > 1:
        data = data.mean(axis=1)
    input_values = wav2vec2_processor(data, sampling_rate=16_000, return_tensors="pt").input_values
    input_values = input_values.to(torch.float32)
    with torch.no_grad():
        logits = wav2vec2_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = wav2vec2_processor.decode(predicted_ids[0])
    chat_output = chat({"question": transcription})
    return chat_output["answer"]
# Main function to handle the inputs
def chatbot(history, text=None, img=None, audio=None):
    text_output = handle_text(text) if text else ''
    img_output = handle_image(img) if img is not None else ''
    audio_output = handle_audio(audio) if audio is not None else ''
    outputs = [o for o in [text_output, img_output, audio_output] if o]
    output = "\n".join(outputs)
    # Append a new chat turn, then stream the reply character by character.
    history = history + [[text or "", ""]]
    yield history
    for character in output:
        history[-1][1] += character
        time.sleep(0.05)
        yield history
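# Gradio passes event inputs positionally, so the image and audio events below
# need thin wrappers that route their value to the right keyword argument;
# wiring chatbot directly would hand the image/audio value to the text parameter.
def chatbot_image(history, img):
    yield from chatbot(history, img=img)

def chatbot_audio(history, audio):
    yield from chatbot(history, audio=audio)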
with gr.Blocks() as demo:
chat_interface = gr.Chatbot([], elem_id="chatbot", height=750)
with gr.Row():
with gr.Column(scale=0.85):
text_input = gr.Textbox(
show_label=False,
placeholder="Input Text here...",
container=False
)
with gr.Column(scale=0.15, min_width=0):
img_input = gr.Image()
audio_input = gr.Audio(source="microphone", label="Audio Input")
    # chatbot is a generator, so these events must run through the queue
    # (queue=False would disable streaming); only the chat history is updated.
    text_msg = text_input.submit(chatbot, [chat_interface, text_input], chat_interface)
    img_msg = img_input.upload(chatbot_image, [chat_interface, img_input], chat_interface)
    # A microphone recording fires stop_recording rather than upload when it ends.
    audio_msg = audio_input.stop_recording(chatbot_audio, [chat_interface, audio_input], chat_interface)
demo.queue()
demo.launch(share=True)