from sentence_transformers import SentenceTransformer, util
from huggingface_hub import hf_hub_download
import pickle
import pandas as pd
import gradio as gr

pd.options.mode.chained_assignment = None  # Turn off SettingWithCopyWarning

# Hugging Face Hub repo holding the precomputed embeddings, verses, songs and lyrics.
REPO_ID = "NimaBoscarino/playlist-generator"
# Art used for songs without cover art and for tiles that have no song to show.
PLACEHOLDER_ART = "https://i.imgur.com/bgCDfT1.jpg"
# Number of album-art tiles in the UI; generate_playlist returns one image update per tile.
NUM_TILES = 9


def _load_pickle(filename):
    """Download ``filename`` from REPO_ID and unpickle it, closing the file afterwards.

    NOTE(review): pickle.load can execute arbitrary code, so this trusts the
    contents of REPO_ID.
    """
    with open(hf_hub_download(REPO_ID, filename=filename), "rb") as f:
        return pickle.load(f)


pickled = _load_pickle("clean-large_embeddings_msmarco-MiniLM-L-6-v3.pkl")
songs = pd.read_csv(hf_hub_download(REPO_ID, filename="songs_new.csv"))
verses = _load_pickle("verses.pkl")
lyrics = pd.read_csv(hf_hub_download(REPO_ID, filename="lyrics_new.csv"))

embedder = SentenceTransformer('msmarco-MiniLM-L-6-v3')

genius_ids = pickled["genius_ids"]
corpus_embeddings = pickled["embeddings"]


def _top_songs(prompt, limit=NUM_TILES):
    """Return up to ``limit`` rows of ``songs``, ordered by their best-matching verse.

    The prompt is embedded and searched against the verse embeddings (top 20
    hits). Each song keeps only its highest-ranked verse, and songs are returned
    in that ranking order.
    """
    prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
    hits = util.semantic_search(prompt_embedding, corpus_embeddings, top_k=20)
    hits = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])

    # corpus_id is used as a positional index into `verses`; the first
    # occurrence of each genius_id is its best-scoring verse.
    verse_match = verses.iloc[hits['corpus_id']]
    verse_match = verse_match.drop_duplicates(subset=["genius_id"])
    ranked_ids = verse_match["genius_id"].values

    # Work on an explicit copy rather than assigning into a filtered view.
    song_match = songs[songs["genius_id"].isin(ranked_ids)].copy()
    # Sorting a categorical follows category order, i.e. the search ranking.
    song_match["genius_id"] = pd.Categorical(song_match["genius_id"], categories=ranked_ids)
    song_match = song_match.sort_values("genius_id")
    return song_match.iloc[:limit]


def generate_playlist(prompt):
    """Build a playlist for ``prompt``.

    Returns a Radio update listing the matched song titles, followed by exactly
    NUM_TILES Image updates (one per album-art tile). When fewer than NUM_TILES
    distinct songs match, the remaining tiles are reset to the placeholder art,
    so the number of returned values always matches the wired outputs.
    """
    song_match = _top_songs(prompt)

    song_names = list(song_match["full_title"])
    song_art = list(song_match["art"].fillna(PLACEHOLDER_ART))
    # Pad so Gradio always receives one update per tile (1 radio + NUM_TILES images).
    song_art += [PLACEHOLDER_ART] * (NUM_TILES - len(song_art))
    images = [gr.Image.update(value=art, visible=True) for art in song_art]
    return (
        gr.Radio.update(label="Songs", interactive=True, choices=song_names),
        *images
    )


def set_lyrics(full_title):
    """Return a Textbox update showing the lyrics of the song titled ``full_title``.

    The first lyrics row whose genius_id belongs to a song with that title is
    used. If nothing matches (e.g. the selection was cleared), the textbox is
    emptied instead of raising IndexError.
    """
    song_ids = songs.loc[songs["full_title"] == full_title, "genius_id"]
    matches = lyrics.loc[lyrics["genius_id"].isin(song_ids), "text"]
    lyrics_text = matches.iloc[0] if not matches.empty else ""
    return gr.Textbox.update(value=lyrics_text)


def set_example_prompt(example):
    """Copy a clicked example sample into the prompt box (its first element is the prompt text)."""
    return gr.TextArea.update(value=example[0])
demo = gr.Blocks()

with demo:
    gr.Markdown(""" # Playlist Generator 📻 🎵 """)

    with gr.Row():
        # Left column: prompt entry plus clickable example prompts.
        with gr.Column():
            gr.Markdown(""" Enter a prompt and generate a playlist based on ✨semantic similarity✨ This was built using Sentence Transformers and Gradio – [read more here!](#) """)
            song_prompt = gr.TextArea(
                value="Running wild and free",
                placeholder="Enter a song prompt, or choose an example",
            )
            example_prompts = gr.Dataset(
                components=[song_prompt],
                samples=[
                    ["I feel nostalgic for the past"],
                    ["Running wild and free"],
                    ["I'm deeply in love with someone I just met!"],
                    ["My friends mean the world to me"],
                    ["Sometimes I feel like no one understands"],
                ],
            )

        # Middle column: generate button, 3x3 album-art grid, and the song picker.
        with gr.Column():
            fetch_songs = gr.Button(value="Generate Your Playlist 🧑🏽‍🎤").style(full_width=True)

            # Workaround because of the Gallery issues: a fixed 3x3 grid of Image
            # components, created row by row (same order as tile1..tile9).
            tiles = []
            for _grid_row in range(3):
                with gr.Row():
                    for _grid_col in range(3):
                        tiles.append(
                            gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
                        )
            tile1, tile2, tile3, tile4, tile5, tile6, tile7, tile8, tile9 = tiles

            song_option = gr.Radio(label="Songs", interactive=True, choices=None, type="value")

        # Right column: lyrics of the selected song.
        with gr.Column():
            verse = gr.Textbox(label="Verse", placeholder="Select a song to see its lyrics")

    # Event wiring: the button fills the song list and every tile.
    fetch_songs.click(
        fn=generate_playlist,
        inputs=[song_prompt],
        outputs=[song_option, *tiles],
    )
    # Clicking an example copies it into the prompt box.
    example_prompts.click(
        fn=set_example_prompt,
        inputs=example_prompts,
        outputs=example_prompts.components,
    )
    # Choosing a song shows its lyrics.
    song_option.change(
        fn=set_lyrics,
        inputs=[song_option],
        outputs=[verse],
    )

demo.launch()