|
import streamlit as st |
|
from timeit import default_timer as timer |
|
import torch |
|
import numpy as np |
|
import pandas as pd |
|
from huggingface_hub import hf_hub_download |
|
from MusicCaps.bart import BartCaptionModel |
|
from MusicCaps.audio_utils import load_audio, STR_CH_FIRST |
|
from diffusers import StableDiffusionPipeline, I2VGenXLPipeline |
|
from diffusers.utils import export_to_video, load_image |
|
import tensorflow as tf |
|
import torch |
|
|
|
# Have TensorFlow allocate GPU memory on demand rather than reserving all of
# it up front (presumably so the PyTorch models loaded below still fit).
# Growth must be enabled per GPU; configuring only the first one leaves every
# other GPU fully reserved.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

# Device used by the PyTorch side of the app; falls back to CPU without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
@st.cache_resource
def load_text_model():
    """Build the BART audio-captioning model from the local checkpoint.

    Wrapped in ``st.cache_resource`` so the checkpoint is loaded once and the
    same model object is reused across reruns. The state dict is loaded with
    ``map_location='cpu'`` and the model is switched to eval mode.
    """
    caption_model = BartCaptionModel(max_length=128)
    checkpoint = torch.load('MusicCaps/transfer.pth', map_location='cpu')
    caption_model.load_state_dict(checkpoint['state_dict'])

    if torch.cuda.is_available():
        # NOTE(review): this only selects the default CUDA device; the model
        # itself is never moved off the CPU here.
        torch.cuda.set_device(device)

    caption_model.eval()
    return caption_model
|
|
|
def get_audio(audio_path, duration=10, target_sr=16000):
    """Load an audio file and cut it into fixed-length mono chunks.

    Args:
        audio_path: Path handed to ``load_audio``.
        duration: Length of each chunk in seconds.
        target_sr: Sample rate requested from ``load_audio``.

    Returns:
        A ``float32`` tensor of shape ``(n_chunks, int(duration * target_sr))``.
        Audio shorter than one chunk is zero-padded to exactly one chunk;
        for longer audio, a trailing partial chunk is dropped.
    """
    n_samples = int(duration * target_sr)
    audio, _ = load_audio(
        path=audio_path,
        ch_format=STR_CH_FIRST,
        sample_rate=target_sr,
        downmix_to_mono=True,
    )

    if audio.ndim == 2:
        # Average the channel axis. NumPy's second positional argument to
        # ``mean`` is ``dtype`` (not torch's ``keepdim``), so pass the axis by
        # keyword; ``mean(0, False)`` raised a TypeError here.
        audio = audio.mean(axis=0)

    if audio.shape[-1] < n_samples:
        padded = np.zeros(n_samples)
        padded[: audio.shape[-1]] = audio
        audio = padded

    # Floor division: only whole chunks are kept.
    n_chunks = audio.shape[-1] // n_samples
    chunks = np.stack(np.split(audio[: n_chunks * n_samples], n_chunks))
    return torch.from_numpy(chunks.astype('float32'))
|
|
|
def _format_timestamp(seconds):
    """Format a non-negative number of seconds as ``m:ss``."""
    minutes, secs = divmod(int(seconds), 60)
    return f"{minutes}:{secs:02d}"


def captioning(model, audio_path, duration=10):
    """Caption an audio file one fixed-length chunk at a time.

    Args:
        model: Captioning model; called as
            ``model.generate(samples=..., num_beams=5)``.
        audio_path: Path forwarded to ``get_audio``.
        duration: Chunk length in seconds. It is passed to ``get_audio`` and
            also used for the time labels, so the two cannot disagree.

    Returns:
        One string per chunk, formatted ``"[start-end]\\n<caption> \\n \\n"``,
        where ``start`` and ``end`` are ``m:ss`` offsets into the clip.
    """
    audio_tensor = get_audio(audio_path=audio_path, duration=duration)

    with torch.no_grad():
        captions = model.generate(
            samples=audio_tensor,
            num_beams=5,
        )

    inference = []
    # zip() caps the loop at the number of chunks actually fed to the model.
    for chunk, text in zip(range(audio_tensor.shape[0]), captions):
        start = _format_timestamp(chunk * duration)
        end = _format_timestamp((chunk + 1) * duration)
        inference.append(f"[{start}-{end}]\n{text} \n \n")
    return inference
|
|
|
@st.cache_resource
def load_image_model():
    """Load Stable Diffusion v1.5 and apply the project's LoRA weights.

    The pipeline is placed on the module-level ``device`` instead of a
    hard-coded ``"cuda"``, so CPU-only hosts no longer crash at startup.
    float16 is used only when CUDA is available; on CPU the pipeline is
    loaded in float32, since fp16 inference is intended for GPUs.
    """
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=dtype
    ).to(device)
    pipeline.load_lora_weights(
        "LoRA dataset/Weights/pytorch_lora_weights.safetensors",
        weight_name="pytorch_lora_weights.safetensors",
    )
    return pipeline
|
|
|
@st.cache_resource
def load_video_model():
    """Load the I2VGen-XL image-to-video pipeline onto ``device``.

    The fp16 weight variant is always downloaded. Compute runs in float16 on
    CUDA and in float32 on CPU, where half precision is not the intended mode.
    """
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    pipeline = I2VGenXLPipeline.from_pretrained(
        "ali-vilab/i2vgen-xl", torch_dtype=dtype, variant="fp16"
    )
    # Without an explicit move the pipeline stays on the CPU even when a GPU
    # is available. NOTE: on small GPUs, diffusers'
    # ``pipeline.enable_model_cpu_offload()`` (needs ``accelerate``) is the
    # lower-memory alternative.
    return pipeline.to(device)
|
|
|
# Load (or fetch from Streamlit's resource cache) the three models, in order:
# audio-to-caption, text-to-image, image-to-video.
A2C_model, image_service, video_model = (
    load_text_model(),
    load_image_model(),
    load_video_model(),
)
|
|
|
# Seed each session-state slot the UI reads, leaving existing values intact
# across reruns.
for _state_key in ("audio_input", "captions", "image", "video"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
|
|
|
def _reset_generated_outputs():
    """Drop captions and media that were derived from an earlier upload."""
    st.session_state.captions = None
    st.session_state.image = None
    st.session_state.video = None
    st.session_state.video_path = None


def _caption_upload(model, uploaded_file):
    """Return the caption string for the first chunk of an uploaded clip.

    ``captioning`` forwards its argument to ``load_audio(path=...)``, so the
    in-memory upload is written to a temporary file and that path is passed
    instead. The temporary file is removed afterwards, even on error.
    """
    import os
    import tempfile

    suffix = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getvalue())
    try:
        return captioning(model, tmp.name)[0]
    finally:
        os.remove(tmp.name)


def _prompt_from_caption(caption):
    """Strip the leading ``[start-end]`` time label from a caption chunk.

    ``captioning`` puts the label on its own first line; only the caption
    text below it is meaningful as a diffusion prompt.
    """
    return caption.split("\n", 1)[-1].strip()


st.title("Testing MusicCaps")

st.session_state.audio_input = st.file_uploader(
    "Insert Your Audio Clips Here",
    type=["wav", "mp3"],
    key="Audio input",
    # A new upload invalidates everything generated from the previous one.
    on_change=_reset_generated_outputs,
)

if st.session_state.audio_input:
    audio_input = st.session_state.audio_input
    st.audio(audio_input)

    if st.button("Generate text prompt"):
        with st.spinner("Captioning audio..."):
            new_captions = _caption_upload(A2C_model, audio_input)
        # New captions make any previously generated image/video stale.
        _reset_generated_outputs()
        st.session_state.captions = new_captions

    captions = st.session_state.captions
    if captions:
        # Only show text once a caption exists (previously printed "None").
        st.text(captions)

        if st.button("Generate Image and video from text prompt"):
            prompt = _prompt_from_caption(captions)
            with st.spinner("Generating image..."):
                st.session_state.image = image_service(prompt).images[0]
            with st.spinner("Generating video..."):
                st.session_state.video = video_model(
                    prompt=prompt,
                    image=st.session_state.image,
                    num_inference_steps=50,
                ).frames[0]
            # No output path: export_to_video writes to a fresh temporary
            # .mp4 and returns its path, so concurrent sessions don't clobber
            # a shared "generated.mp4".
            st.session_state.video_path = export_to_video(
                st.session_state.video, fps=7
            )

        # Render from session state so results survive unrelated reruns.
        if st.session_state.image is not None and st.session_state.get("video_path"):
            c1, c2 = st.columns([1, 1])
            with c1:
                st.image(st.session_state.image)
            with c2:
                st.video(st.session_state.video_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|