import streamlit as st
from timeit import default_timer as timer
import torch
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from MusicCaps.bart import BartCaptionModel
from MusicCaps.audio_utils import load_audio, STR_CH_FIRST
from diffusers import StableDiffusionPipeline, I2VGenXLPipeline
from diffusers.utils import export_to_video, load_image
import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand instead of reserving it all,
# so the PyTorch pipelines below can share the same device.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

device = "cuda:0" if torch.cuda.is_available() else "cpu"


@st.cache_resource
def load_text_model():
    """Load the BART-based audio captioning model from a local checkpoint."""
    model = BartCaptionModel(max_length=128)
    pretrained_object = torch.load('MusicCaps/transfer.pth', map_location='cpu')
    state_dict = pretrained_object['state_dict']
    model.load_state_dict(state_dict)
    if torch.cuda.is_available():
        torch.cuda.set_device(device)
    model.eval()
    return model


def get_audio(audio_path, duration=10, target_sr=16000):
    """Load an audio file and split it into 10-second chunks at 16 kHz."""
    n_samples = int(duration * target_sr)
    audio, sr = load_audio(
        path=audio_path,
        ch_format=STR_CH_FIRST,
        sample_rate=target_sr,
        downmix_to_mono=True,
    )
    if len(audio.shape) == 2:
        audio = audio.mean(0)  # downmix to mono
    input_size = int(n_samples)
    if audio.shape[-1] < input_size:  # pad short clips up to one full chunk
        pad = np.zeros(input_size)
        pad[: audio.shape[-1]] = audio
        audio = pad
    # Keep only whole chunks and stack them into a (n_chunks, n_samples) tensor.
    n_chunks = int(audio.shape[-1] // n_samples)
    audio = torch.from_numpy(
        np.stack(np.split(audio[: n_chunks * n_samples], n_chunks)).astype('float32')
    )
    return audio


def captioning(model, audio_path):
    """Generate one timestamped caption per 10-second chunk of the audio."""
    audio_tensor = get_audio(audio_path=audio_path)
    # if device is not None:
    #     audio_tensor = audio_tensor.to(device)
    with torch.no_grad():
        output = model.generate(
            samples=audio_tensor,
            num_beams=5,
        )
    inference = []
    number_of_chunks = range(audio_tensor.shape[0])
    for chunk, text in zip(number_of_chunks, output):
        time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
        inference.append(f"{time}\n{text}\n\n")
    return inference


@st.cache_resource
def load_image_model():
    """Load Stable Diffusion v1.5 and attach the fine-tuned LoRA weights."""
    pipeline = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    # load_lora_weights takes the weights directory plus the file name.
    pipeline.load_lora_weights(
        "LoRA dataset/Weights", weight_name="pytorch_lora_weights.safetensors"
    )
    return pipeline


# @st.cache_resource
# def load_video_model():
#     pipeline = I2VGenXLPipeline.from_pretrained(
#         "ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16"
#     )
#     return pipeline


A2C_model = load_text_model()
image_service = load_image_model()
# video_model = load_video_model()

# Initialise session state so intermediate results survive Streamlit reruns.
if "audio_input" not in st.session_state:
    st.session_state.audio_input = None
if "captions" not in st.session_state:
    st.session_state.captions = None
if "image" not in st.session_state:
    st.session_state.image = None
if "video" not in st.session_state:
    st.session_state.video = None

st.title("Insage")
st.session_state.audio_input = st.file_uploader(
    "Insert Your Audio Clips Here", type=["wav", "mp3"], key="Audio input"
)

if st.session_state.audio_input:
    audio_input = st.session_state.audio_input
    st.audio(audio_input)
    if st.button("Generate text prompt"):
        # Use the caption of the first 10-second chunk as the prompt.
        st.session_state.captions = captioning(A2C_model, audio_input)[0]
        st.text(st.session_state.captions)

if st.session_state.captions:
    if st.button("Generate Image from text prompt"):
        st.session_state.image = image_service(st.session_state.captions).images[0]
        # video = video_model(
        #     prompt=st.session_state.captions,
        #     image=st.session_state.image,
        #     num_inference_steps=50
        # ).frames[0]
        # st.session_state.video = video
        # export_to_video(video, "generated.mp4", fps=7)

# Display the generated image from session state so it persists across reruns.
if st.session_state.image is not None:
    st.image(st.session_state.image)
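# Usage sketch (assumption: this script is saved as app.py at the repo root, with
# the MusicCaps checkpoint and LoRA weights available at the paths referenced above):
#   streamlit run app.py
# The app then accepts a .wav or .mp3 upload, captions it chunk by chunk, and feeds
# the first chunk's caption to the LoRA-tuned Stable Diffusion pipeline as a prompt.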