File size: 5,211 Bytes
043d857
 
3d738ec
043d857
 
 
 
 
 
3d738ec
 
 
 
 
043d857
 
 
3d738ec
043d857
 
 
 
 
 
 
 
 
3d738ec
 
 
 
043d857
 
 
 
 
 
 
 
 
 
 
 
 
3d738ec
043d857
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from sentence_transformers import SentenceTransformer, util
from huggingface_hub import hf_hub_download
import os
import pickle
import pandas as pd
import gradio as gr

pd.options.mode.chained_assignment = None  # Turn off SettingWithCopyWarning

# Credential for the private dataset repo: the value of the TOKEN_FROM_SECRET
# environment variable when it is set and non-empty, otherwise True (which,
# per the huggingface_hub docs, means "use the locally stored login token").
auth_token = os.environ.get("TOKEN_FROM_SECRET") or True

# Precomputed embeddings, stored as a pickled dict with at least the keys
# "song_ids" and "embeddings" (both are read below).
# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# this is only safe as long as the dataset repo it is downloaded from is trusted.
embeddings_path = hf_hub_download(
    "NimaBoscarino/playlist-generator",
    repo_type="dataset",
    filename="clean-large_embeddings_msmarco-MiniLM-L-6-v3.pkl",
)
with open(embeddings_path, "rb") as embeddings_file:
    # Context manager closes the handle; the original left it open.
    pickled = pickle.load(embeddings_file)

# Song metadata from the public dataset repo.
songs = pd.read_csv(
    hf_hub_download(
        "NimaBoscarino/playlist-generator",
        repo_type="dataset",
        filename="songs_new.csv",
    )
)

# Verses and full lyrics live in the private repo, so the credential resolved
# above is passed through (it was previously computed but never used).
verses = pd.read_csv(
    hf_hub_download(
        "NimaBoscarino/playlist-generator-private",
        repo_type="dataset",
        filename="verses.csv",
        use_auth_token=auth_token,
    )
)
lyrics = pd.read_csv(
    hf_hub_download(
        "NimaBoscarino/playlist-generator-private",
        repo_type="dataset",
        filename="lyrics_new.csv",
        use_auth_token=auth_token,
    )
)

# Model used to embed the user's prompt at query time.
# NOTE(review): presumably the same model that produced the stored embeddings
# (the pickle's filename carries the same model name) - confirm.
embedder = SentenceTransformer('msmarco-MiniLM-L-6-v3')

song_ids = pickled["song_ids"]  # Not referenced elsewhere in this file
corpus_embeddings = pickled["embeddings"]  # Search corpus for generate_playlist


def generate_playlist(prompt):
    """Build the Gradio updates for a playlist matching *prompt*.

    The prompt is embedded with the module-level ``embedder`` and compared
    against ``corpus_embeddings``; the 20 best hits are looked up positionally
    in ``verses``, reduced to one row per song, and the first 9 songs (in hit
    order) are kept.

    Args:
        prompt: Free-text prompt taken from the prompt text area.

    Returns:
        A tuple of exactly 10 Gradio updates: one ``gr.Radio`` update whose
        choices are the matched songs' ``full_title`` values, followed by 9
        ``gr.Image`` updates carrying each song's ``art`` URL. Songs without
        artwork, and tiles left over when fewer than 9 songs match, show the
        placeholder image.
    """
    max_songs = 9  # One per image tile wired as an output of this callback
    placeholder_art = "https://i.imgur.com/bgCDfT1.jpg"

    prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
    hits = util.semantic_search(prompt_embedding, corpus_embeddings, top_k=20)
    # One query was searched, so only the first result list is relevant.
    hits = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])

    # NOTE(review): assumes row i of `verses` corresponds to embedding i of
    # the corpus - confirm against how the pickle was generated.
    verse_match = verses.iloc[hits['corpus_id']]
    # Several hits may come from the same song; keep the best-ranked one.
    verse_match = verse_match.drop_duplicates(subset=["song_id"])
    ranked_ids = verse_match["song_id"].values

    # Work on an explicit copy so the column assignment below never writes
    # through to (or warns about) the shared `songs` frame.
    song_match = songs[songs["song_id"].isin(ranked_ids)].copy()
    # An ordered categorical makes sort_values follow the hit ranking rather
    # than the natural ordering of the ids.
    song_match["song_id"] = pd.Categorical(song_match["song_id"], categories=ranked_ids)
    song_match = song_match.sort_values("song_id")
    song_match = song_match[0:max_songs]  # Only grab the top 9

    song_names = list(song_match["full_title"])
    song_art = list(song_match["art"].fillna(placeholder_art))
    # The click handler always has 9 image outputs; returning fewer values
    # makes Gradio reject the result, so pad the remainder with placeholders.
    song_art += [placeholder_art] * (max_songs - len(song_art))
    images = [gr.Image.update(value=art, visible=True) for art in song_art]

    return (
        gr.Radio.update(label="Songs", interactive=True, choices=song_names),
        *images
    )


def set_lyrics(full_title):
    """Return a Textbox update showing the lyrics of the selected song.

    Args:
        full_title: The ``full_title`` value chosen in the songs radio group.
            May be empty or ``None`` when nothing is selected.

    Returns:
        A ``gr.Textbox`` update whose value is the ``text`` of the first
        ``lyrics`` row belonging to a song with that title, or an empty string
        when no title is given or no lyrics row matches (previously this case
        raised ``IndexError``).
    """
    if not full_title:
        return gr.Textbox.update(value="")

    # Titles are resolved to song ids first because `lyrics` is keyed by id.
    matching_ids = songs.loc[songs["full_title"] == full_title, "song_id"]
    matching_lyrics = lyrics.loc[lyrics["song_id"].isin(matching_ids), "text"]
    if matching_lyrics.empty:
        return gr.Textbox.update(value="")

    return gr.Textbox.update(value=matching_lyrics.iloc[0])


def set_example_prompt(example):
    """Return a TextArea update that loads the clicked example prompt.

    Args:
        example: The clicked sample; only its first element is used, as the
            new text value.

    Returns:
        A ``gr.TextArea`` update whose value is ``example[0]``.
    """
    prompt_text = example[0]
    return gr.TextArea.update(value=prompt_text)


# Artwork shown on every tile before a playlist has been generated
PLACEHOLDER_ART = "https://i.imgur.com/bgCDfT1.jpg"

# Prompts offered as one-click examples beneath the prompt box
EXAMPLE_PROMPTS = (
    "I feel nostalgic for the past",
    "Running wild and free",
    "I'm deeply in love with someone I just met!",
    "My friends mean the world to me",
    "Sometimes I feel like no one understands",
)

demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # Playlist Generator 📻 🎵
        """)

    with gr.Row():
        # Left column: introduction, prompt input and example prompts
        with gr.Column():
            gr.Markdown(
                """
                Enter a prompt and generate a playlist based on ✨semantic similarity✨
                This was built using Sentence Transformers and Gradio – [read more here!](#)
                """)

            song_prompt = gr.TextArea(
                value="Running wild and free",
                placeholder="Enter a song prompt, or choose an example",
            )
            # Each sample is a one-element row matching the single component
            example_prompts = gr.Dataset(
                components=[song_prompt],
                samples=[[text] for text in EXAMPLE_PROMPTS],
            )

        # Middle column: generate button, 3x3 artwork grid and song picker
        with gr.Column():
            fetch_songs = gr.Button(value="Generate Your Playlist 🧑🏽‍🎤").style(full_width=True)

            # Workaround because of the Gallery issues: nine separate images
            # laid out as three rows of three
            tiles = []
            for _row in range(3):
                with gr.Row():
                    for _col in range(3):
                        tiles.append(
                            gr.Image(value=PLACEHOLDER_ART, show_label=False, visible=True)
                        )

            song_option = gr.Radio(label="Songs", interactive=True, choices=None, type="value")

        # Right column: lyrics of the currently selected song
        with gr.Column():
            verse = gr.Textbox(label="Verse", placeholder="Select a song to see its lyrics")

    # Generate button -> refresh the song choices and all nine tiles
    fetch_songs.click(fn=generate_playlist, inputs=[song_prompt], outputs=[song_option, *tiles])

    # Clicking an example -> copy its text into the prompt box
    example_prompts.click(
        fn=set_example_prompt,
        inputs=example_prompts,
        outputs=example_prompts.components,
    )

    # Picking a song -> show its lyrics
    song_option.change(fn=set_lyrics, inputs=[song_option], outputs=[verse])

demo.launch()