Spaces:
Runtime error
Runtime error
from sentence_transformers import SentenceTransformer, util | |
from huggingface_hub import hf_hub_download | |
import os | |
import pickle | |
import pandas as pd | |
import gradio as gr | |
# Module start-up: download the data files, load them into memory and
# instantiate the sentence-embedding model used to encode prompts.

pd.options.mode.chained_assignment = None  # Turn off SettingWithCopyWarning

# Token for the private dataset repo. Falls back to True, which asks
# huggingface_hub to use whatever credentials are cached locally.
auth_token = os.environ.get("TOKEN_FROM_SECRET") or True

_embeddings_path = hf_hub_download(
    "NimaBoscarino/playlist-generator",
    repo_type="dataset",
    filename="clean-large_embeddings_msmarco-MiniLM-L-6-v3.pkl",
)
# NOTE(security): pickle.load executes arbitrary code embedded in the file.
# Only keep this as long as the dataset repo above is fully trusted.
with open(_embeddings_path, "rb") as _embeddings_file:
    pickled = pickle.load(_embeddings_file)

songs = pd.read_csv(
    hf_hub_download(
        "NimaBoscarino/playlist-generator",
        repo_type="dataset",
        filename="songs_new.csv",
    )
)

# The private repo needs authentication: pass the configured token instead of
# a hard-coded True so that the TOKEN_FROM_SECRET value is actually used.
verses = pd.read_csv(
    hf_hub_download(
        "NimaBoscarino/playlist-generator-private",
        repo_type="dataset",
        filename="verses.csv",
        use_auth_token=auth_token,
    )
)
lyrics = pd.read_csv(
    hf_hub_download(
        "NimaBoscarino/playlist-generator-private",
        repo_type="dataset",
        filename="lyrics_new.csv",
        use_auth_token=auth_token,
    )
)

embedder = SentenceTransformer('msmarco-MiniLM-L-6-v3')

song_ids = pickled["song_ids"]
corpus_embeddings = pickled["embeddings"]
def generate_playlist(prompt):
    """Build a playlist of up to nine songs whose verses best match ``prompt``.

    The prompt is embedded and compared against ``corpus_embeddings``; the
    20 best hits are mapped to rows of ``verses``, reduced to one verse per
    song, and the matching rows of ``songs`` are ordered the same way.

    Args:
        prompt: Free-text prompt entered by the user.

    Returns:
        A tuple of ten Gradio updates: one for the song radio (its choices
        are the matched song titles) followed by exactly nine image updates.
        When fewer than nine songs match, the remaining images show the
        placeholder art so the number of returned values always matches the
        number of output components.
    """
    playlist_size = 9  # one song per image tile wired to this handler
    placeholder_art = "https://i.imgur.com/bgCDfT1.jpg"

    prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
    hits = util.semantic_search(prompt_embedding, corpus_embeddings, top_k=20)
    hits = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])

    # Several hits can belong to the same song; keep only the first verse
    # found for each song so that every song appears once.
    verse_match = verses.iloc[hits['corpus_id']]
    verse_match = verse_match.drop_duplicates(subset=["song_id"])
    ranked_song_ids = verse_match["song_id"].values

    # Work on an explicit copy so the column assignment below never writes
    # through to (or warns about) the shared ``songs`` frame.
    song_match = songs[songs["song_id"].isin(ranked_song_ids)].copy()
    # An ordered categorical makes sort_values follow the verse ranking
    # instead of the natural ordering of the ids.
    song_match["song_id"] = pd.Categorical(
        song_match["song_id"], categories=ranked_song_ids
    )
    song_match = song_match.sort_values("song_id")
    song_match = song_match[:playlist_size]  # Only grab the top 9

    song_names = list(song_match["full_title"])
    song_art = list(song_match["art"].fillna(placeholder_art))
    # The 20 verses may collapse to fewer than nine distinct songs. Pad the
    # art list so one update is returned for every image tile.
    song_art += [placeholder_art] * (playlist_size - len(song_art))

    images = [gr.Image.update(value=art, visible=True) for art in song_art]

    return (
        gr.Radio.update(label="Songs", interactive=True, choices=song_names),
        *images,
    )
def set_lyrics(full_title):
    """Return a textbox update showing the lyrics of the selected song.

    Args:
        full_title: Title chosen in the song radio; looked up in the
            ``full_title`` column of ``songs``.

    Returns:
        A ``gr.Textbox`` update whose value is the ``text`` of the first
        ``lyrics`` row belonging to a song with that title, or an empty
        string when no such row exists (for example when nothing is
        selected).
    """
    matching_song_ids = songs[songs["full_title"] == full_title]["song_id"]
    matching_lyrics = lyrics[lyrics["song_id"].isin(matching_song_ids)]["text"]

    # Guard against an empty selection: ``.iloc[0]`` on an empty Series
    # raises IndexError, which would surface as an error in the UI.
    if matching_lyrics.empty:
        return gr.Textbox.update(value="")

    return gr.Textbox.update(value=matching_lyrics.iloc[0])
def set_example_prompt(example):
    """Copy a clicked example into the prompt text area.

    Args:
        example: The selected sample; its first item is used as the
            prompt text.

    Returns:
        A ``gr.TextArea`` update carrying that text as its new value.
    """
    chosen_prompt = example[0]
    return gr.TextArea.update(value=chosen_prompt)
_PLACEHOLDER_ART = "https://i.imgur.com/bgCDfT1.jpg"


def _placeholder_tile():
    """Create one visible, unlabeled image tile showing the placeholder art."""
    return gr.Image(value=_PLACEHOLDER_ART, show_label=False, visible=True)


# Page layout first, event wiring second. Components must be created inside
# the layout context they belong to, so the creation order below is the
# on-screen order.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Playlist Generator 📻 🎵
""")

    with gr.Row():
        # Column 1: introduction, prompt input and clickable example prompts.
        with gr.Column():
            gr.Markdown(
                """
Enter a prompt and generate a playlist based on ✨semantic similarity✨
This was built using Sentence Transformers and Gradio – [read more here!](#)
""")

            song_prompt = gr.TextArea(
                value="Running wild and free",
                placeholder="Enter a song prompt, or choose an example"
            )

            example_prompts = gr.Dataset(
                components=[song_prompt],
                samples=[
                    ["I feel nostalgic for the past"],
                    ["Running wild and free"],
                    ["I'm deeply in love with someone I just met!"],
                    ["My friends mean the world to me"],
                    ["Sometimes I feel like no one understands"],
                ]
            )

        # Column 2: generate button, 3x3 grid of cover art, song selector.
        with gr.Column():
            fetch_songs = gr.Button(value="Generate Your Playlist 🧑🏽🎤").style(full_width=True)

            with gr.Row():
                tile1, tile2, tile3 = (_placeholder_tile() for _ in range(3))
            with gr.Row():
                tile4, tile5, tile6 = (_placeholder_tile() for _ in range(3))
            with gr.Row():
                tile7, tile8, tile9 = (_placeholder_tile() for _ in range(3))

            # Workaround because of the Gallery issues
            tiles = [tile1, tile2, tile3, tile4, tile5, tile6, tile7, tile8, tile9]

            song_option = gr.Radio(label="Songs", interactive=True, choices=None, type="value")

        # Column 3: lyrics of the currently selected song.
        with gr.Column():
            verse = gr.Textbox(label="Verse", placeholder="Select a song to see its lyrics")

    # Button click -> refresh the song list and all nine cover tiles.
    fetch_songs.click(
        fn=generate_playlist,
        inputs=[song_prompt],
        outputs=[song_option] + tiles,
    )

    # Clicking an example copies its text into the prompt box.
    example_prompts.click(
        fn=set_example_prompt,
        inputs=example_prompts,
        outputs=example_prompts.components,
    )

    # Picking a song shows its lyrics.
    song_option.change(
        fn=set_lyrics,
        inputs=[song_option],
        outputs=[verse]
    )

demo.launch()