File size: 1,972 Bytes
feaaa7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import streamlit as st

import torch
import sentence_transformers as sent
import datasets as ds

@st.cache(allow_output_mutation=True)
def load_wikipedia_dataset():
    """Load and cache the Wikipedia Simple English dataset.

    Streamlit re-runs the whole script on every user interaction, so the
    (slow) Hugging Face dataset load must be cached like the model and
    embeddings below.  allow_output_mutation=True skips hashing the large
    dataset object on each rerun.
    """
    return ds.load_dataset("wikipedia", "20220301.simple")


d = load_wikipedia_dataset()
t = d["train"]        # the only split in this config
titles = t['title']   # article titles, indexed by corpus_id

@st.cache(allow_output_mutation=True)
def load_model():
    """Load and cache the multilingual sentence-transformer encoder.

    allow_output_mutation=True keeps Streamlit from hashing the large,
    internally mutable model object on every script rerun.
    """
    return sent.SentenceTransformer("distiluse-base-multilingual-cased-v1")

@st.cache(allow_output_mutation=True)
def load_wikipedia_embeddings():
    """Load and cache the precomputed title embeddings tensor.

    map_location=cpu lets the app run on hosts without a GPU even if the
    tensor was saved from one.  allow_output_mutation=True (matching
    load_model) stops Streamlit from re-hashing the large tensor on every
    rerun, which the plain @st.cache decorator would otherwise do.
    """
    return torch.load("titles-simple-0.pt", map_location=torch.device('cpu'))


# --- Page layout: title, usage instructions, cached resources, query box ---
st.title("Multilingual Semantic Search for Wikipedia Simple English")
st.markdown("""
Use semantic search to find related articles in Wikipedia Simple English: using a language model (sentence-transformers/distiluse-base-multilingual-cased-v1) we can find the closests titles from Wikipedia Simple English (wikipedia) queried in any of the model's trained languages: Arabic, Chinese, Dutch, English, French, German, Italian, Korean, Polish, Portuguese, Russian, Spanish, Turkish:


- colesterol
- développement humain
- Crise dos mísseis de Cuba


Also, "near natural language" queries are usually enough to bring up relevant results. Try:


- ¿cuál es el edificio más alto del mundo?
- comment préparer du poulet frit
- melhores películas de pixar


(note: search is done only on the article titles, not the content)
""")
model = load_model()                      # cached sentence-transformer encoder
embeddings = load_wikipedia_embeddings()  # cached precomputed title embeddings

#queries = ["Aristoteles", "Autismo", "Mental", "crecimiento poblacional"]
# Free-text search box; reruns the script (and the search below) on each edit.
query = st.text_input("Query (es, fr, pt, ...)")

if query != "":
    # semantic_search expects a batch of queries; we search a single one.
    queries = [query]
    queries_emb = model.encode(queries, convert_to_tensor=True)

    # Cosine-similarity search of the query embedding against all
    # precomputed title embeddings; top_k best matches per query.
    hits = sent.util.semantic_search(queries_emb, embeddings, top_k=5)

    for i, q in enumerate(queries):
        # Explicit st.write instead of Streamlit "magic" bare expressions;
        # rendering is identical but the intent is visible.
        st.write(f"----\n{q}:\n")
        for hit in hits[i]:
            cid = hit['corpus_id']
            # Fetch the dataset row once (the original fetched t[cid]
            # twice, doing a redundant dataset lookup per hit).
            row = t[cid]
            title = titles[cid]
            st.header(f"{title}")
            st.write(row['url'])
            st.write(row['text'][:500] + "...")  # article preview
            st.write(hit)  # raw hit dict: corpus_id + similarity score