ThetaM2V / app.py
Monke64's picture
Changed title
47ba60e
raw
history blame contribute delete
No virus
4.18 kB
import streamlit as st
from timeit import default_timer as timer
import torch
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from MusicCaps.bart import BartCaptionModel
from MusicCaps.audio_utils import load_audio, STR_CH_FIRST
from diffusers import StableDiffusionPipeline, I2VGenXLPipeline
from diffusers.utils import export_to_video, load_image
import tensorflow as tf
import torch
# Enable TensorFlow's memory-growth option on the first GPU it reports
# (presumably so TensorFlow does not pre-allocate GPU memory that the torch
# models also need -- TODO confirm intent). Other GPUs are left untouched.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

# Torch device string read by load_text_model(): first CUDA GPU, else CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
@st.cache_resource
def load_text_model():
    """Build the BART audio-captioning model and load its checkpoint.

    The weights are read from ``MusicCaps/transfer.pth`` with
    ``map_location='cpu'`` and the model is switched to evaluation mode.
    When CUDA is available the current CUDA device is set to the
    module-level ``device``; the model itself is not moved to that device
    in this function.

    Returns:
        The ``BartCaptionModel`` instance, in eval mode.
    """
    captioner = BartCaptionModel(max_length=128)

    # The checkpoint is a dict; only its 'state_dict' entry is used.
    checkpoint = torch.load('MusicCaps/transfer.pth', map_location='cpu')
    captioner.load_state_dict(checkpoint['state_dict'])

    if torch.cuda.is_available():
        # Selects the active CUDA device only; no tensors are moved here.
        torch.cuda.set_device(device)

    captioner.eval()
    return captioner
def get_audio(audio_path, duration=10, target_sr=16000):
    """Load an audio file and cut it into fixed-length chunks.

    Args:
        audio_path: Forwarded to ``load_audio`` as ``path``.
        duration: Chunk length in seconds.
        target_sr: Sample rate requested from ``load_audio``.

    Returns:
        A float32 ``torch.Tensor`` holding the stacked chunks, each
        ``int(duration * target_sr)`` samples long (one row per chunk,
        assuming ``load_audio`` yields a 1-D array -- TODO confirm).
        Audio shorter than one chunk is zero-padded to exactly one chunk;
        trailing samples that do not fill a whole chunk are dropped.
    """
    chunk_len = int(duration * target_sr)

    waveform, _ = load_audio(
        path=audio_path,
        ch_format=STR_CH_FIRST,
        sample_rate=target_sr,
        downmix_to_mono=True,
    )

    if len(waveform.shape) == 2:
        # NOTE(review): on a NumPy array the second positional argument of
        # mean() is ``dtype``, not ``keepdim``; this branch is presumably
        # never reached because downmix_to_mono=True -- confirm.
        waveform = waveform.mean(0, False)  # to mono

    n_available = waveform.shape[-1]
    if n_available < chunk_len:
        # Too short for a single chunk: right-pad with zeros.
        padded = np.zeros(chunk_len)
        padded[:n_available] = waveform
        waveform = padded

    # Number of *whole* chunks; any remainder at the tail is discarded.
    num_chunks = int(waveform.shape[-1] // chunk_len)
    usable = waveform[:num_chunks * chunk_len]
    stacked = np.stack(np.split(usable, num_chunks)).astype('float32')
    return torch.from_numpy(stacked)
def captioning(model, audio_path):
    """Generate one caption per audio chunk of the given file.

    Args:
        model: Object whose ``generate(samples=..., num_beams=...)`` returns
            an iterable of caption texts, one per chunk.
        audio_path: Forwarded to ``get_audio``.

    Returns:
        A list of strings, one per chunk. Each string is a bracketed time
        label built from the chunk index times 10, a newline, the caption
        text, and a trailing spaced blank line.
    """
    chunks = get_audio(audio_path=audio_path)
    # The chunks are passed to the model as returned by get_audio; moving
    # them to `device` is intentionally left disabled here.

    with torch.no_grad():
        captions = model.generate(samples=chunks, num_beams=5)

    # Pair each caption with its chunk index; zip stops at the shorter side.
    chunk_indices = range(chunks.shape[0])
    return [
        f"[{index * 10}:00-{(index + 1) * 10}:00]\n{caption} \n \n"
        for index, caption in zip(chunk_indices, captions)
    ]
@st.cache_resource
def load_image_model():
    """Load the Stable Diffusion v1.5 pipeline with the project's LoRA weights.

    The pipeline is placed on CUDA in float16 when a GPU is available.
    Without CUDA it falls back to CPU in float32 instead of failing on the
    hard-coded ``"cuda"`` placement (float16 pipelines are not generally
    usable on CPU).

    Returns:
        The ``StableDiffusionPipeline`` with the LoRA weights from
        ``LoRA dataset/Weights/pytorch_lora_weights.safetensors`` loaded.
    """
    use_cuda = torch.cuda.is_available()
    target_device = "cuda" if use_cuda else "cpu"
    weights_dtype = torch.float16 if use_cuda else torch.float32

    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=weights_dtype,
    ).to(target_device)
    pipeline.load_lora_weights(
        "LoRA dataset/Weights/pytorch_lora_weights.safetensors",
        weight_name="pytorch_lora_weights.safetensors",
    )
    return pipeline
# @st.cache_resource
# def load_video_model():
# pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
# return pipeline
# Instantiate both models at script start; the loaders are wrapped in
# st.cache_resource.
A2C_model = load_text_model()
image_service = load_image_model()
# video_model = load_video_model()

# Seed every session-state slot this script reads so lookups never fail.
for _slot in ("audio_input", "captions", "image", "video"):
    if _slot not in st.session_state:
        st.session_state[_slot] = None

st.title("Insage")
st.session_state.audio_input = st.file_uploader(
    "Insert Your Audio Clips Here",
    type=["wav", "mp3"],
    key="Audio input",
)

audio_input = st.session_state.audio_input
if audio_input:
    st.audio(audio_input)

    if st.button("Generate text prompt"):
        # Only the caption of the first chunk is kept as the prompt.
        st.session_state.captions = captioning(A2C_model, audio_input)[0]
        # A fresh caption invalidates any image made from the previous one.
        st.session_state.image = None

    if st.session_state.captions:
        # Render from session state rather than inside the button branch:
        # st.button is True only for the rerun caused by the click, so
        # output drawn there vanishes on the next interaction.
        st.text(st.session_state.captions)

        if st.button("Generate Image from text prompt"):
            st.session_state.image = image_service(st.session_state.captions).images[0]
            # Disabled video generation, kept for reference:
            # video = video_model(
            #     prompt = st.session_state.captions,
            #     image=st.session_state.image,
            #     num_inference_steps=50
            # ).frames[0]
            # st.session_state.video = video
            # export_to_video(video, "generated.mp4", fps=7)

        if st.session_state.image is not None:
            st.image(st.session_state.image)