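"""Insage: a Streamlit app that captions an uploaded audio clip with a BART-based
audio captioning model and turns the caption into an image with a LoRA-tuned
Stable Diffusion pipeline."""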
import streamlit as st
from timeit import default_timer as timer
import torch
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from MusicCaps.bart import BartCaptionModel
from MusicCaps.audio_utils import load_audio, STR_CH_FIRST
from diffusers import StableDiffusionPipeline, I2VGenXLPipeline
from diffusers.utils import export_to_video, load_image
import tensorflow as tf

# Let TensorFlow allocate GPU memory on demand so it does not reserve the whole
# device that the PyTorch/diffusers pipelines below also need.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

# Use the first CUDA device when available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

@st.cache_resource
def load_text_model():
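    """Load the BART-based audio captioning model from the local MusicCaps checkpoint."""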
    model = BartCaptionModel(max_length = 128)
    pretrained_object = torch.load('MusicCaps/transfer.pth', map_location='cpu')
    state_dict = pretrained_object['state_dict']
    model.load_state_dict(state_dict)
    if torch.cuda.is_available():
        torch.cuda.set_device(device)
    model.eval()
    return model

def get_audio(audio_path, duration=10, target_sr=16000):
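    """Load an audio clip, resample it to ``target_sr`` Hz, downmix to mono, and
    split it into non-overlapping ``duration``-second chunks.

    Returns a float32 tensor of shape (n_chunks, duration * target_sr); clips
    shorter than one chunk are zero-padded.
    """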
    n_samples = int(duration * target_sr)
    audio, sr = load_audio(
        path= audio_path,
        ch_format= STR_CH_FIRST,
        sample_rate= target_sr,
        downmix_to_mono= True,
    )
    if len(audio.shape) == 2:
        audio = audio.mean(axis=0)  # downmix to mono
    input_size = int(n_samples)
    if audio.shape[-1] < input_size:  # pad sequence
        pad = np.zeros(input_size)
        pad[: audio.shape[-1]] = audio
        audio = pad
    # split into non-overlapping chunks of n_samples and stack them into a batch
    n_chunks = int(audio.shape[-1] // n_samples)
    audio = torch.from_numpy(np.stack(np.split(audio[:n_chunks * n_samples], n_chunks)).astype('float32'))
    return audio

def captioning(model, audio_path):
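    """Run beam-search caption generation on each 10-second chunk of the audio
    and return the captions as a list of timestamped strings."""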
    audio_tensor = get_audio(audio_path = audio_path)
    # if device is not None:
    #     audio_tensor = audio_tensor.to(device)
    with torch.no_grad():
        output = model.generate(
            samples=audio_tensor,
            num_beams=5,
        )
    inference = []
    number_of_chunks = range(audio_tensor.shape[0])
    for chunk, text in zip(number_of_chunks, output):
        time = f"[{chunk * 10}:00-{(chunk + 1) * 10}:00]"
        inference.append(f"{time}\n{text} \n \n")
    return inference

@st.cache_resource
def load_image_model():
    # Stable Diffusion v1.5 with the project's LoRA weights applied on top.
    pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
    pipeline.load_lora_weights("LoRA dataset/Weights/pytorch_lora_weights.safetensors", weight_name="pytorch_lora_weights.safetensors")
    return pipeline

# @st.cache_resource
# def load_video_model():
#     pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
#     return pipeline

# Load the captioning and image-generation models once; @st.cache_resource keeps
# them cached across Streamlit reruns.
A2C_model = load_text_model()
image_service = load_image_model()
#video_model = load_video_model()

# Initialise session state so intermediate results survive Streamlit reruns.
if "audio_input" not in st.session_state:
    st.session_state.audio_input = None

if "captions" not in st.session_state:
    st.session_state.captions = None

if "image" not in st.session_state:
    st.session_state.image = None

if "video" not in st.session_state:
    st.session_state.video = None

st.title("Insage")
st.session_state.audio_input = st.file_uploader("Insert Your Audio Clips Here", type=["wav", "mp3"], key="Audio input")
if st.session_state.audio_input:
    audio_input = st.session_state.audio_input
    st.audio(audio_input)
    if st.button("Generate text prompt"):
        # Caption the clip and keep the first 10-second chunk's caption as the prompt.
        st.session_state.captions = captioning(A2C_model, audio_input)[0]
        captions = st.session_state.captions
        st.text(captions)
    if st.session_state.captions:
        if st.button("Generate Image from text prompt"):
            # Generate an image from the caption with the LoRA-tuned Stable Diffusion pipeline.
            st.session_state.image = image_service(st.session_state.captions).images[0]
            # video = video_model(
            #     prompt = st.session_state.captions,
            #     image=st.session_state.image,
            #     num_inference_steps=50
            # ).frames[0]
            # st.session_state.video = video
            # export_to_video(video, "generated.mp4", fps=7)
            st.image(st.session_state.image)