Spaces:
Runtime error
Runtime error
File size: 5,142 Bytes
043d857 3d738ec 043d857 3d738ec 21c3f3d 3d738ec 21c3f3d 5b73951 043d857 3d738ec 043d857 3d738ec 043d857 21c3f3d 043d857 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
from sentence_transformers import SentenceTransformer, util
from huggingface_hub import hf_hub_download
import os
import pickle
import pandas as pd
import gradio as gr
pd.options.mode.chained_assignment = None  # Turn off SettingWithCopyWarning

# Token for the private lyrics dataset, read from the Space secret. When the
# secret is not set, `True` tells huggingface_hub to use the locally saved token.
auth_token = os.environ.get("TOKEN_FROM_SECRET") or True

# Precomputed verse embeddings plus song/verse metadata (no explicit token).
# NOTE(review): pickle.load can execute arbitrary code from the file; this is
# presumably acceptable only because the dataset repo is trusted -- confirm
# before pointing it at any other source.
with open(
    hf_hub_download("NimaBoscarino/playlist-generator", repo_type="dataset", filename="verse-embeddings.pkl"),
    "rb",
) as embeddings_file:  # `with` closes the handle (the original leaked it)
    corpus_embeddings = pickle.load(embeddings_file)
songs = pd.read_csv(hf_hub_download("NimaBoscarino/playlist-generator", repo_type="dataset", filename="songs_new.csv"))
verses = pd.read_csv(hf_hub_download("NimaBoscarino/playlist-generator", repo_type="dataset", filename="verses.csv"))
# The lyrics repo is downloaded with the auth token passed explicitly.
lyrics = pd.read_csv(hf_hub_download("NimaBoscarino/playlist-generator-private", repo_type="dataset", filename="lyrics_new.csv", use_auth_token=auth_token))

# Query encoder; presumably the same model that produced verse-embeddings.pkl,
# since similarity scores are only meaningful between embeddings of one model.
embedder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
def generate_playlist(prompt):
    """Build a playlist of songs whose verses are semantically closest to *prompt*.

    The prompt is embedded with ``embedder`` and searched against
    ``corpus_embeddings``, whose rows are looked up positionally in ``verses``.
    Songs are ranked by their best-matching verse, and at most nine are kept.

    Args:
        prompt: Free-text description of the desired playlist.

    Returns:
        A tuple containing a ``gr.Radio`` update that lists the song titles,
        followed by exactly nine ``gr.Image`` updates, one per album-art tile.
        Tiles without a matching song are hidden, so the output count always
        equals the nine tile components wired up in the UI.
    """
    num_tiles = 9  # must equal the number of image tiles in the layout
    placeholder_art = "https://i.imgur.com/bgCDfT1.jpg"

    prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
    # semantic_search returns one hit list per query (we send a single query),
    # each a list of {"corpus_id", "score"} dicts sorted by decreasing score.
    hits = util.semantic_search(prompt_embedding, corpus_embeddings, top_k=20)
    hits = pd.DataFrame(hits[0], columns=["corpus_id", "score"])

    # Map hits back to verses and keep only each song's best (first) verse.
    verse_match = verses.iloc[hits["corpus_id"]].drop_duplicates(subset=["song_id"])
    ranked_ids = verse_match["song_id"].values

    # Order the matching songs by the rank of their best verse. The explicit
    # copy avoids writing into a slice of the global `songs` frame.
    song_match = songs[songs["song_id"].isin(ranked_ids)].copy()
    song_match["song_id"] = pd.Categorical(song_match["song_id"], categories=ranked_ids)
    song_match = song_match.sort_values("song_id").head(num_tiles)

    song_names = list(song_match["full_title"])
    song_art = list(song_match["art"].fillna(placeholder_art))

    images = [gr.Image.update(value=art, visible=True) for art in song_art]
    # Fewer than nine distinct songs can match, e.g. when several top verses
    # come from the same song. Gradio needs one output per tile, so hide the
    # unused tiles instead of returning a short tuple.
    images += [gr.Image.update(visible=False) for _ in range(num_tiles - len(images))]

    return (
        gr.Radio.update(label="Songs", interactive=True, choices=song_names),
        *images,
    )
def set_lyrics(full_title):
    """Return a Textbox update showing the lyrics of the song titled *full_title*.

    Args:
        full_title: The selected value of the song ``Radio``. This is a
            ``full_title`` from ``songs``, or possibly ``None`` when nothing is
            selected.

    Returns:
        A ``gr.Textbox.update`` holding the first matching lyrics ``text``.
        If no lyrics match, the value is empty, so the textbox placeholder
        shows instead of the handler raising ``IndexError``.
    """
    song_ids = songs.loc[songs["full_title"] == full_title, "song_id"]
    matching_text = lyrics.loc[lyrics["song_id"].isin(song_ids), "text"]
    if matching_text.empty:
        return gr.Textbox.update(value="")
    return gr.Textbox.update(value=matching_text.iloc[0])
def set_example_prompt(example):
    """Copy a clicked example prompt into the prompt text area.

    Args:
        example: The clicked ``gr.Dataset`` sample. Each sample defined in the
            UI is a one-item list holding the prompt text; presumably Gradio
            passes that list through unchanged.

    Returns:
        A ``gr.TextArea`` update whose value is the sample's first element.
    """
    prompt_text = example[0]
    return gr.TextArea.update(value=prompt_text)
# Placeholder album art shown in every tile until a playlist is generated.
placeholder_art = "https://i.imgur.com/bgCDfT1.jpg"

with gr.Blocks() as demo:
    gr.Markdown(
"""
# Playlist Generator 📻 🎵
""")
    with gr.Row():
        # Left column: description, prompt box and clickable example prompts.
        with gr.Column():
            gr.Markdown(
"""
Enter a prompt and generate a playlist based on ✨semantic similarity✨
This was built using Sentence Transformers and Gradio – [see the blog](https://huggingface.co/blog/your-first-ml-project)!
""")
            song_prompt = gr.TextArea(
                value="Running wild and free",
                placeholder="Enter a song prompt, or choose an example",
            )
            example_prompts = gr.Dataset(
                components=[song_prompt],
                samples=[
                    ["I feel nostalgic for the past"],
                    ["Running wild and free"],
                    ["I'm deeply in love with someone I just met!"],
                    ["My friends mean the world to me"],
                    ["Sometimes I feel like no one understands"],
                ],
            )

        # Middle column: generate button, 3x3 album-art grid and song picker.
        with gr.Column():
            fetch_songs = gr.Button(value="Generate Your Playlist 🧑🏽🎤").style(full_width=True)
            # Workaround because of the Gallery issues: nine Image components
            # laid out as three rows of three stand in for a gallery.
            tiles = []
            for _row in range(3):
                with gr.Row():
                    tiles.extend(
                        gr.Image(value=placeholder_art, show_label=False, visible=True)
                        for _col in range(3)
                    )
            tile1, tile2, tile3, tile4, tile5, tile6, tile7, tile8, tile9 = tiles
            song_option = gr.Radio(label="Songs", interactive=True, choices=None, type="value")

        # Right column: lyrics of the selected song.
        with gr.Column():
            verse = gr.Textbox(label="Verse", placeholder="Select a song to see its lyrics")

    # Event wiring. generate_playlist must return one Radio update + one
    # update per tile, in this order.
    fetch_songs.click(
        fn=generate_playlist,
        inputs=[song_prompt],
        outputs=[song_option, *tiles],
    )
    example_prompts.click(
        fn=set_example_prompt,
        inputs=example_prompts,
        outputs=example_prompts.components,
    )
    song_option.change(
        fn=set_lyrics,
        inputs=[song_option],
        outputs=[verse],
    )

demo.launch()
|